Example #1
def train_single_thread(
        actor, critic, target_actor, target_critic, args, prepare_fn,
        global_episode, global_update_step, episodes_queue):
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    _, update_fn, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)

    logger = Logger(args.logdir)

    buffer = create_buffer(args)

    if args.prioritized_replay:
        beta_decay_fn = create_decay_fn(
            "linear",
            initial_value=args.prioritized_replay_beta0,
            final_value=1.0,
            max_step=args.max_update_steps)

    actor_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.actor_lr,
        final_value=args.actor_lr_end,
        max_step=args.max_update_steps)
    critic_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.critic_lr,
        final_value=args.critic_lr_end,
        max_step=args.max_update_steps)

    update_step = 0
    received_examples = 1  # start at 1 so the "updates per example" summary never divides by zero
    while global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads) \
            and global_update_step.value < args.max_update_steps * args.num_train_threads:
        actor_lr = actor_learning_rate_decay_fn(update_step)
        critic_lr = critic_learning_rate_decay_fn(update_step)

        actor_lr = min(args.actor_lr, max(args.actor_lr_end, actor_lr))
        critic_lr = min(args.critic_lr, max(args.critic_lr_end, critic_lr))

        while True:
            try:
                replay = episodes_queue.get_nowait()
                for (observation, action, reward, next_observation, done) in replay:
                    buffer.add(observation, action, reward, next_observation, done)
                received_examples += len(replay)
            except py_queue.Empty:
                break

        if len(buffer) >= args.train_steps:
            if args.prioritized_replay:
                beta = beta_decay_fn(update_step)
                beta = min(1.0, max(args.prioritized_replay_beta0, beta))
                (tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones,
                 weights, batch_idxes) = \
                    buffer.sample(
                        batch_size=args.batch_size,
                        beta=beta)
            else:
                (tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones) = \
                    buffer.sample(batch_size=args.batch_size)
                weights, batch_idxes = np.ones_like(tr_rewards), None

            step_metrics, step_info = update_fn(
                tr_observations, tr_actions, tr_rewards,
                tr_next_observations, tr_dones,
                weights, actor_lr, critic_lr)

            update_step += 1
            global_update_step.value += 1

            if args.prioritized_replay:
                new_priorities = np.abs(step_info["td_error"]) + 1e-6
                buffer.update_priorities(batch_idxes, new_priorities)

            for key, value in step_metrics.items():
                value = to_numpy(value)[0]
                logger.scalar_summary(key, value, update_step)

            logger.scalar_summary("actor lr", actor_lr, update_step)
            logger.scalar_summary("critic lr", critic_lr, update_step)

            if update_step % args.save_step == 0:
                save_fn(update_step)
        else:
            time.sleep(1)

        logger.scalar_summary("buffer size", len(buffer), global_episode.value)
        logger.scalar_summary(
            "updates per example",
            update_step * args.batch_size / received_examples,
            global_episode.value)

    save_fn(update_step)

    raise KeyboardInterrupt  # training budget exhausted: raise to stop this worker
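All of the training examples call a create_decay_fn helper that is not part of this listing. Below is a minimal sketch with the signatures inferred from the call sites above; the cosine shape of the "cycle" schedule is an assumption, and the callers clamp the returned values in any case.

import math


def create_decay_fn(mode, *, initial_value, final_value,
                    max_step=None, cycle_len=None, num_cycles=None):
    """Return a function mapping a step index to a scheduled value."""
    if mode == "linear":
        def decay_fn(step):
            # linear interpolation from initial_value to final_value over max_step steps
            frac = min(step, max_step) / float(max_step)
            return initial_value + (final_value - initial_value) * frac
    elif mode == "cycle":
        def decay_fn(step):
            # cosine decay from initial_value to final_value within each cycle,
            # settling at final_value once num_cycles cycles have passed
            if num_cycles is not None and step >= cycle_len * num_cycles:
                return final_value
            pos = (step % cycle_len) / float(cycle_len)
            return final_value + 0.5 * (initial_value - final_value) * (1.0 + math.cos(math.pi * pos))
    else:
        raise ValueError("unknown decay mode: {}".format(mode))
    return decay_fn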
Example #2
def play_single_thread(
        actor, critic, target_actor, target_critic, args, prepare_fn,
        global_episode, global_update_step, episodes_queue,
        best_reward):
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    act_fn, _, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)

    logger = Logger(args.logdir)
    env = create_env(args)
    random_process = create_random_process(args)

    epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)

    epsilon_decay_fn = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=epsilon_cycle_len,
        num_cycles=args.max_episodes // epsilon_cycle_len)

    episode = 1
    step = 0
    start_time = time.time()
    while global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads) \
            and global_update_step.value < args.max_update_steps * args.num_train_threads:
        if episode % 100 == 0:
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)

        epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))

        episode_metrics = {
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon
        }

        observation = env.reset(seed=seed, difficulty=args.difficulty)
        random_process.reset_states()
        done = False

        replay = []
        while not done:
            action = act_fn(observation, noise=epsilon * random_process.sample())
            next_observation, reward, done, _ = env.step(action)

            replay.append((observation, action, reward, next_observation, done))
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1

            observation = next_observation

        episodes_queue.put(replay)

        episode += 1
        global_episode.value += 1

        if episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)

            if episode_metrics["reward"] > 15.0 * args.reward_scale:
                save_fn(episode)

        step += episode_metrics["step"]
        elapsed_time = time.time() - start_time

        for key, value in episode_metrics.items():
            logger.scalar_summary(key, value, episode)
        logger.scalar_summary(
            "episode per minute",
            episode / elapsed_time * 60,
            episode)
        logger.scalar_summary(
            "step per second",
            step / elapsed_time,
            episode)

        if elapsed_time > 86400 * args.max_train_days:
            # time budget exhausted: push the shared counter past the limit so every worker exits
            global_episode.value = args.max_episodes * (args.num_threads - args.num_train_threads) + 1

    raise KeyboardInterrupt  # training budget exhausted: raise to stop this worker
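The launcher that spawns these play/train workers is not shown; the shared counters and the queue behave like multiprocessing primitives. A hypothetical wiring is sketched below; the run_workers name and the split of roles by thread index are assumptions, not the original script.

import copy
import multiprocessing as mp


def run_workers(actor, critic, target_actor, target_critic, args, prepare_fn):
    global_episode = mp.Value("i", 0)       # episodes played across all workers
    global_update_step = mp.Value("i", 0)   # gradient updates across all workers
    best_reward = mp.Value("d", 0.0)        # best episode reward seen so far
    episodes_queue = mp.Queue()             # play workers hand episodes to train workers

    workers = []
    for thread in range(args.num_threads):
        worker_args = copy.copy(args)
        worker_args.thread = thread
        # assumed role split: the first num_train_threads workers train, the rest play;
        # in practice the networks would need share_memory() or per-process construction
        if thread < args.num_train_threads:
            target = train_single_thread
            fn_args = (actor, critic, target_actor, target_critic, worker_args,
                       prepare_fn, global_episode, global_update_step,
                       episodes_queue)
        else:
            target = play_single_thread
            fn_args = (actor, critic, target_actor, target_critic, worker_args,
                       prepare_fn, global_episode, global_update_step,
                       episodes_queue, best_reward)
        process = mp.Process(target=target, args=fn_args)
        process.start()
        workers.append(process)

    for process in workers:
        process.join()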
Example #3
def train_multi_thread(actor, critic, target_actor, target_critic, args, prepare_fn, best_reward):
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    act_fn, update_fn, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
    logger = Logger(args.logdir)

    buffer = create_buffer(args)
    if args.prioritized_replay:
        beta_decay_fn = create_decay_fn(
            "linear",
            initial_value=args.prioritized_replay_beta0,
            final_value=1.0,
            max_step=args.max_episodes)

    env = create_env(args)
    random_process = create_random_process(args)

    actor_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.actor_lr,
        final_value=args.actor_lr_end,
        max_step=args.max_episodes)
    critic_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.critic_lr,
        final_value=args.critic_lr_end,
        max_step=args.max_episodes)

    epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)

    epsilon_decay_fn = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=epsilon_cycle_len,
        num_cycles=args.max_episodes // epsilon_cycle_len)

    episode = 0
    step = 0
    start_time = time.time()
    while episode < args.max_episodes:
        if episode % 100 == 0:
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)

        actor_lr = actor_learning_rate_decay_fn(episode)
        critic_lr = critic_learning_rate_decay_fn(episode)
        epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))

        episode_metrics = {
            "value_loss": 0.0,
            "policy_loss": 0.0,
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon
        }

        observation = env.reset(seed=seed, difficulty=args.difficulty)
        random_process.reset_states()
        done = False

        while not done:
            action = act_fn(observation, noise=epsilon * random_process.sample())
            next_observation, reward, done, _ = env.step(action)

            buffer.add(observation, action, reward, next_observation, done)
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1

            if len(buffer) >= args.train_steps:

                if args.prioritized_replay:
                    (tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones,
                     weights, batch_idxes) = \
                        buffer.sample(batch_size=args.batch_size, beta=beta_decay_fn(episode))
                else:
                    (tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones) = \
                        buffer.sample(batch_size=args.batch_size)
                    weights, batch_idxes = np.ones_like(tr_rewards), None

                step_metrics, step_info = update_fn(
                    tr_observations, tr_actions, tr_rewards,
                    tr_next_observations, tr_dones,
                    weights, actor_lr, critic_lr)

                if args.prioritized_replay:
                    new_priorities = np.abs(step_info["td_error"]) + 1e-6
                    buffer.update_priorities(batch_idxes, new_priorities)

                for key, value in step_metrics.items():
                    value = to_numpy(value)[0]
                    episode_metrics[key] += value

            observation = next_observation

        episode += 1

        if episode_metrics["reward"] > 15.0 * args.reward_scale \
                and episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)
            save_fn(episode)

        step += episode_metrics["step"]
        elapsed_time = time.time() - start_time

        for key, value in episode_metrics.items():
            value = value if "loss" not in key else value / episode_metrics["step"]
            logger.scalar_summary(key, value, episode)
        logger.scalar_summary(
            "episode per minute",
            episode / elapsed_time * 60,
            episode)
        logger.scalar_summary(
            "step per second",
            step / elapsed_time,
            episode)
        logger.scalar_summary("actor lr", actor_lr, episode)
        logger.scalar_summary("critic lr", critic_lr, episode)

        if episode % args.save_step == 0:
            save_fn(episode)

        if elapsed_time > 86400 * args.max_train_days:
            # time budget exhausted: force the outer loop to exit
            episode = args.max_episodes + 1

    save_fn(episode)

    raise KeyboardInterrupt  # training budget exhausted: raise to stop this worker
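create_buffer is another helper that is not part of the listing. Below is a minimal uniform replay buffer with the interface the examples rely on; args.buffer_size is an assumed attribute, and the prioritized variant (sample with beta, update_priorities) is only noted in a comment.

from collections import deque

import numpy as np


class ReplayBuffer:
    """Uniform experience replay with the interface used above."""

    def __init__(self, capacity):
        self._storage = deque(maxlen=capacity)

    def __len__(self):
        return len(self._storage)

    def add(self, observation, action, reward, next_observation, done):
        self._storage.append(
            (observation, action, reward, next_observation, done))

    def sample(self, batch_size):
        # sample transitions uniformly with replacement
        idxes = np.random.randint(len(self._storage), size=batch_size)
        batch = [self._storage[i] for i in idxes]
        observations, actions, rewards, next_observations, dones = zip(*batch)
        return (np.array(observations), np.array(actions), np.array(rewards),
                np.array(next_observations), np.array(dones, dtype=np.float32))


def create_buffer(args):
    # a prioritized variant would additionally accept beta in sample() and
    # expose update_priorities(batch_idxes, new_priorities)
    return ReplayBuffer(args.buffer_size)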
class Trainer:
    """ Train and Validation with single GPU """
    def __init__(self, train_loader, val_loader, args):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.args = args
        self.model = get_model(args)
        self.epochs = args.epochs
        self.total_step = len(train_loader) * args.epochs
        self.step = 0
        self.epoch = 0
        self.start_epoch = 1
        self.lr = args.learning_rate
        self.best_acc = 0

        # Log
        self.log_path = (
                PROJECT_ROOT / Path(SAVE_DIR) /
                Path(datetime.now().strftime("%Y%m%d%H%M%S") + "-")
                ).as_posix()
        self.log_path = Path(self.get_dirname(self.log_path, args))
        self.log_path.mkdir(parents=True, exist_ok=True)
        self.logger = Logger("train", self.log_path, args.verbose)
        self.logger.log("Checkpoint files will be saved in {}".format(self.log_path))

        self.logger.add_level('STEP', 21, 'green')
        self.logger.add_level('EPOCH', 22, 'cyan')
        self.logger.add_level('EVAL', 23, 'yellow')

        self.criterion = nn.CrossEntropyLoss()
        if self.args.cuda:
            self.criterion = self.criterion.cuda()
        if args.half:
            self.model.half()
            self.criterion.half()

        params = self.model.parameters()
        self.optimizer = get_optimizer(args.optimizer, params, args)

    def train(self):
        self.eval()
        for self.epoch in range(self.start_epoch, self.args.epochs+1):
            self.adjust_learning_rate([int(self.args.epochs/2), int(self.args.epochs*3/4)], factor=0.1)
            self.train_epoch()
            self.eval()

        self.logger.writer.export_scalars_to_json(
            self.log_path.as_posix() + "/scalars-{}-{}-{}.json".format(
                self.args.model,
                self.args.seed,
                self.args.activation
            )
        )
        self.logger.writer.close()

    def train_epoch(self):
        self.model.train()
        eval_metrics = EvaluationMetrics(['Loss', 'Acc', 'Time'])

        for i, (images, labels) in enumerate(self.train_loader):
            st = time.time()
            self.step += 1
            images = torch.autograd.Variable(images)
            labels = torch.autograd.Variable(labels)
            if self.args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half: images = images.half()

            outputs, loss = self.compute_loss(images, labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            outputs = outputs.float()
            loss = loss.float()
            elapsed_time = time.time() - st

            _, preds = torch.max(outputs, 1)
            accuracy = (labels == preds.squeeze()).float().mean()

            batch_size = labels.size(0)
            eval_metrics.update('Loss', float(loss), batch_size)
            eval_metrics.update('Acc', float(accuracy), batch_size)
            eval_metrics.update('Time', elapsed_time, batch_size)

            if self.step % self.args.log_step == 0:
                self.logger.scalar_summary(eval_metrics.val, self.step, 'STEP')

        # Histogram of parameters
        for tag, p in self.model.named_parameters():
            tag = tag.split(".")
            tag = "Train_{}".format(tag[0]) + "/" + "/".join(tag[1:])
            try:
                self.logger.writer.add_histogram(tag, p.clone().cpu().data.numpy(), self.step)
                self.logger.writer.add_histogram(tag+'/grad', p.grad.clone().cpu().data.numpy(), self.step)
            except Exception as e:
                print("Check if variable {} is not used: {}".format(tag, e))

        self.logger.scalar_summary(eval_metrics.avg, self.step, 'EPOCH')


    def eval(self):
        self.model.eval()
        eval_metrics = EvaluationMetrics(['Loss', 'Acc', 'Time'])

        for i, (images, labels) in enumerate(self.val_loader):
            st = time.time()
            images = torch.autograd.Variable(images)
            labels = torch.autograd.Variable(labels)
            if self.args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half: images = images.half()

            outputs, loss = self.compute_loss(images, labels)

            outputs = outputs.float()
            loss = loss.float()
            elapsed_time = time.time() - st

            _, preds = torch.max(outputs, 1)
            accuracy = (labels == preds.squeeze()).float().mean()

            batch_size = labels.size(0)
            eval_metrics.update('Loss', float(loss), batch_size)
            eval_metrics.update('Acc', float(accuracy), batch_size)
            eval_metrics.update('Time', elapsed_time, batch_size)

        # Save best model
        if eval_metrics.avg['Acc'] > self.best_acc:
            self.save()
            self.logger.log("Saving best model: epoch={}".format(self.epoch))
            self.best_acc = eval_metrics.avg['Acc']
            self.maybe_delete_old_pth(log_path=self.log_path.as_posix(), max_to_keep=1)

        self.logger.scalar_summary(eval_metrics.avg, self.step, 'EVAL')

    def get_dirname(self, path, args):
        path += "{}-".format(getattr(args, 'dataset'))
        path += "{}-".format(getattr(args, 'seed'))
        path += "{}".format(getattr(args, 'model'))
        return path

    def save(self, filename=None):
        if filename is None:
            filename = os.path.join(self.log_path, 'model-{}.pth'.format(self.epoch))
        torch.save({
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': self.epoch,
            'best_acc': self.best_acc,
            'args': self.args
        }, filename)

    def load(self, filename=None):
        if filename is None: filename = self.log_path
        S = torch.load(filename) if self.args.cuda else torch.load(filename, map_location=lambda storage, location: storage)
        self.model.load_state_dict(S['model'])
        self.optimizer.load_state_dict(S['optimizer'])
        self.epoch = S['epoch']
        self.best_acc = S['best_acc']
        self.args = S['args']

    def maybe_delete_old_pth(self, log_path, max_to_keep):
        """Model filename must end with xxx-xxx-[epoch].pth
        """
        # filename and time
        pths = [(f, int(f[:-4].split("-")[-1])) for f in os.listdir(log_path) if f.endswith('.pth')]
        if len(pths) > max_to_keep:
            sorted_pths = sorted(pths, key=lambda tup: tup[1])
            for i in range(len(pths) - max_to_keep):
                os.remove(os.path.join(log_path, sorted_pths[i][0]))

    def show_current_model(self):
        print("\n".join("{}: {}".format(k, v) for k, v in sorted(vars(self.args).items())))

        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        total_params = np.sum([np.prod(p.size()) for p in model_parameters])

        print('%s\n\n'%(type(self.model)))
        print('%s\n\n'%(inspect.getsource(self.model.__init__)))
        print('%s\n\n'%(inspect.getsource(self.model.forward)))

        # table is 95 characters wide (40 + 10 + 45)
        print("*"*40 + "%10s" % self.args.model + "*"*45)
        print("*"*40 + "PARAM INFO" + "*"*45)
        print("-"*95)
        print("| %40s | %25s | %20s |" % ("Param Name", "Shape", "Number of Params"))
        print("-"*95)
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print("| %40s | %25s | %20d |" % (name, list(param.size()), np.prod(param.size())))
        print("-"*95)
        print("Total Params: %d" % (total_params))
        print("*"*95)

    def adjust_learning_rate(self, milestone, factor=0.1):
        if self.epoch in milestone:
            self.lr *= factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr

    def compute_loss(self, images, labels):
        outputs = self.model(images)
        loss = self.criterion(outputs, labels)
        return outputs, loss
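The repository's own entry point is not included. A hypothetical driver for the Trainer above is sketched below, assuming a torchvision CIFAR-10 pipeline and an args namespace that carries the attributes the class reads (batch_size, epochs, cuda, and so on).

import torch
import torchvision
import torchvision.transforms as transforms


def run_training(args):
    # example data pipeline; the real project may build its loaders differently
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_set = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform)
    val_set = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False, num_workers=2)

    trainer = Trainer(train_loader, val_loader, args)
    trainer.show_current_model()
    trainer.train()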
class Defender(Trainer):
    """ Perform various adversarial attacks and defense on a pretrained model
    Scheme generates Tensor, not Variable
    """
    def __init__(self, val_loader, args, **kwargs):
        self.val_loader = val_loader
        self.args = args
        self.model = get_model(args)
        self.step = 0
        self.cuda = self.args.cuda

        self.log_path = (
            PROJECT_ROOT / Path("experiments") /
            Path(datetime.now().strftime("%Y%m%d%H%M%S") + "-")).as_posix()
        self.log_path = Path(self.get_dirname(self.log_path, args))
        self.log_path.mkdir(parents=True, exist_ok=True)
        self.logger = Logger("defense", self.log_path, args.verbose)
        self.logger.log("Checkpoint files will be saved in {}".format(
            self.log_path))

        self.logger.add_level("ATTACK", 21, 'yellow')
        self.logger.add_level("DEFENSE", 22, 'cyan')
        self.logger.add_level("TEST", 23, 'white')
        self.logger.add_level("DIST", 11, 'white')

        self.kwargs = kwargs
        if args.domain_restrict:
            self.artifact = get_artifact(self.model, val_loader, args)
            self.kwargs['artifact'] = self.artifact

    def defend(self):
        self.model.eval()
        defense_scheme = getattr(defenses,
                                 self.args.defense)(self.model, self.args,
                                                    **self.kwargs)
        source = self.model
        if self.args.source is not None and (self.args.ckpt_name !=
                                             self.args.ckpt_src):
            target = self.args.ckpt_name
            self.args.model = self.args.source
            self.args.ckpt_name = self.args.ckpt_src
            source = get_model(self.args)
            self.logger.log("Transfer attack from {} -> {}".format(
                self.args.ckpt_src, target))
        attack_scheme = getattr(attacks, self.args.attack)(source, self.args,
                                                           **self.kwargs)

        eval_metrics = EvaluationMetrics(
            ['Test/Acc', 'Test/Top5', 'Test/Time'])
        eval_def_metrics = EvaluationMetrics(
            ['Def-Test/Acc', 'Def-Test/Top5', 'Def-Test/Time'])
        attack_metrics = EvaluationMetrics(
            ['Attack/Acc', 'Attack/Top5', 'Attack/Time'])
        defense_metrics = EvaluationMetrics(
            ['Defense/Acc', 'Defense/Top5', 'Defense/Time'])
        dist_metrics = EvaluationMetrics(['L0', 'L1', 'L2', 'Li'])

        for i, (images, labels) in enumerate(self.val_loader):
            self.step += 1
            if self.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half: images = images.half()

            # Inference
            st = time.time()
            outputs = self.model(self.to_var(images, self.cuda, True))
            outputs = outputs.float()
            _, preds = torch.topk(outputs, 5)

            acc = (labels == preds.data[:, 0]).float().mean()
            top5 = torch.sum(
                (labels.unsqueeze(1).repeat(1, 5) == preds.data).float(),
                dim=1).mean()
            eval_metrics.update('Test/Acc', float(acc), labels.size(0))
            eval_metrics.update('Test/Top5', float(top5), labels.size(0))
            eval_metrics.update('Test/Time', time.time() - st, labels.size(0))

            # Attacker
            st = time.time()
            adv_images, adv_labels = attack_scheme.generate(images, labels)
            if isinstance(adv_images, Variable):
                adv_images = adv_images.data
            attack_metrics.update('Attack/Time',
                                  time.time() - st, labels.size(0))

            # Lp distance
            diff = torch.abs(
                denormalize(adv_images, self.args.dataset) -
                denormalize(images, self.args.dataset))
            L0 = torch.sum((torch.sum(diff, dim=1) > 1e-3).float().view(
                labels.size(0), -1),
                           dim=1).mean()
            diff = diff.view(labels.size(0), -1)
            L1 = torch.norm(diff, p=1, dim=1).mean()
            L2 = torch.norm(diff, p=2, dim=1).mean()
            Li = torch.max(diff, dim=1)[0].mean()
            dist_metrics.update('L0', float(L0), labels.size(0))
            dist_metrics.update('L1', float(L1), labels.size(0))
            dist_metrics.update('L2', float(L2), labels.size(0))
            dist_metrics.update('Li', float(Li), labels.size(0))

            # Defender
            st = time.time()
            def_images, def_labels = defense_scheme.generate(
                adv_images, adv_labels)
            if isinstance(
                    def_images, Variable
            ):  # FIXME - Variable in Variable out for all methods
                def_images = def_images.data
            defense_metrics.update('Defense/Time',
                                   time.time() - st, labels.size(0))
            self.calc_stats('Attack', adv_images, images, adv_labels, labels,
                            attack_metrics)
            self.calc_stats('Defense', def_images, images, def_labels, labels,
                            defense_metrics)

            # Defense-Inference for shift of original image
            st = time.time()
            def_images_org, _ = defense_scheme.generate(images, labels)
            if isinstance(
                    def_images_org, Variable
            ):  # FIXME - Variable in Variable out for all methods
                def_images_org = def_images_org.data
            outputs = self.model(self.to_var(def_images_org, self.cuda, True))
            outputs = outputs.float()
            _, preds = torch.topk(outputs, 5)

            acc = (labels == preds.data[:, 0]).float().mean()
            top5 = torch.sum(
                (labels.unsqueeze(1).repeat(1, 5) == preds.data).float(),
                dim=1).mean()
            eval_def_metrics.update('Def-Test/Acc', float(acc), labels.size(0))
            eval_def_metrics.update('Def-Test/Top5', float(top5),
                                    labels.size(0))
            eval_def_metrics.update('Def-Test/Time',
                                    time.time() - st, labels.size(0))

            if self.step % self.args.log_step == 0 or self.step == len(
                    self.val_loader):
                self.logger.scalar_summary(eval_metrics.avg, self.step, 'TEST')
                self.logger.scalar_summary(eval_def_metrics.avg, self.step,
                                           'TEST')
                self.logger.scalar_summary(attack_metrics.avg, self.step,
                                           'ATTACK')
                self.logger.scalar_summary(defense_metrics.avg, self.step,
                                           'DEFENSE')
                self.logger.scalar_summary(dist_metrics.avg, self.step, 'DIST')

                defense_rate = eval_metrics.avg[
                    'Test/Acc'] - defense_metrics.avg['Defense/Acc']
                if eval_metrics.avg['Test/Acc'] - attack_metrics.avg[
                        'Attack/Acc']:
                    defense_rate /= eval_metrics.avg[
                        'Test/Acc'] - attack_metrics.avg['Attack/Acc']
                else:
                    defense_rate = 0
                defense_rate = 1 - defense_rate

                defense_top5 = eval_metrics.avg[
                    'Test/Top5'] - defense_metrics.avg['Defense/Top5']
                if eval_metrics.avg['Test/Top5'] - attack_metrics.avg[
                        'Attack/Top5']:
                    defense_top5 /= eval_metrics.avg[
                        'Test/Top5'] - attack_metrics.avg['Attack/Top5']
                else:
                    defense_top5 = 0
                defense_top5 = 1 - defense_top5

                self.logger.log(
                    "Defense Rate Top1: {:5.3f} | Defense Rate Top5: {:5.3f}".
                    format(defense_rate, defense_top5), 'DEFENSE')

            if self.step % self.args.img_log_step == 0:
                image_dict = {
                    'Original':
                    to_np(denormalize(images, self.args.dataset))[0],
                    'Attacked':
                    to_np(denormalize(adv_images, self.args.dataset))[0],
                    'Defended':
                    to_np(denormalize(def_images, self.args.dataset))[0],
                    'Perturbation':
                    to_np(denormalize(images - adv_images,
                                      self.args.dataset))[0]
                }
                self.logger.image_summary(image_dict, self.step)

    def calc_stats(self, method, gen_images, images, gen_labels, labels,
                   metrics):
        """gen_images: Generated from attacker or defender
        Currently just calculating acc and artifact
        """
        success_rate = 0

        if not isinstance(gen_images, Variable):
            gen_images = self.to_var(gen_images.clone(), self.cuda, True)
        gen_outputs = self.model(gen_images)
        gen_outputs = gen_outputs.float()
        _, gen_preds = torch.topk(F.softmax(gen_outputs, dim=1), 5)

        if isinstance(gen_preds, Variable):
            gen_preds = gen_preds.data
        gen_acc = (labels == gen_preds[:, 0]).float().mean()
        gen_top5 = torch.sum(
            (labels.unsqueeze(1).repeat(1, 5) == gen_preds).float(),
            dim=1).mean()

        metrics.update('{}/Acc'.format(method), float(gen_acc), labels.size(0))
        metrics.update('{}/Top5'.format(method), float(gen_top5),
                       labels.size(0))

    def to_var(self, x, cuda, volatile=False):
        """For CPU inference manual cuda setting is needed
        """
        if cuda:
            x = x.cuda()
        return torch.autograd.Variable(x, volatile=volatile)
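The attacks and defenses modules are not included here. Judging from how Defender constructs and calls them, each scheme takes (model, args, **kwargs) and exposes generate(images, labels) returning an (images, labels) pair; the do-nothing placeholder below only illustrates that assumed interface.

class IdentityScheme:
    """Placeholder for the assumed attack/defense scheme interface."""

    def __init__(self, model, args, **kwargs):
        self.model = model
        self.args = args
        self.kwargs = kwargs

    def generate(self, images, labels):
        # a real attack would perturb the images here (e.g. FGSM or PGD);
        # a real defense would try to undo or mask that perturbation
        return images, labels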
Example #6
def learn(
        env,
        policy_func,
        args,
        *,
        timesteps_per_batch,  # what to train on
        max_kl,
        cg_iters,
        gamma,
        lam,  # advantage estimation
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3):
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    entbonus = entcoeff * meanent

    vferr = U.mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = U.mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("pol")
    ]
    vf_var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("vf")
    ]
    vfadam = MpiAdam(vf_var_list)

    policy_var_list = [
        v for v in all_var_list if v.name.split("/")[0].startswith("pi")
    ]
    saver = MpiSaver(policy_var_list, log_prefix=args.logdir)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n(
        [U.sum(g * tangent) for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    saver.restore(restore_from=args.restore_actor_from)
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     args,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    logger = Logger(args.logdir)

    while time.time() - tstart < 86400 * args.max_train_days:
        # logger.log("********** Iteration %i ************" % iters_so_far)
        meanlosses = [0] * len(loss_names)
        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        segargs = seg["ob"], seg["ac"], seg["adv"]
        fvpargs = [arr[::5] for arr in segargs]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*segargs)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            pass  # got zero gradient, not updating
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product,
                             g,
                             cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*segargs)))
                improve = surr - surrbefore
                # logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    pass  # non-finite losses: shrink the step
                elif kl > max_kl * 1.5:
                    pass  # violated the KL constraint: shrink the step
                elif improve < 0:
                    pass  # surrogate did not improve: shrink the step
                else:
                    break  # step size OK, accept it
                stepsize *= .5
            else:
                # could not compute a good step: revert to the old parameters
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        saver.sync()

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        # Logging
        logger.scalar_summary("episodes", len(lens), iters_so_far)

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.scalar_summary(lossname, lossval, episodes_so_far)

        logger.scalar_summary("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret),
                              episodes_so_far)

        logger.scalar_summary("step", np.mean(lenbuffer), episodes_so_far)
        logger.scalar_summary("reward", np.mean(rewbuffer), episodes_so_far)
        logger.scalar_summary("best reward", np.max(rewbuffer),
                              episodes_so_far)

        elapsed_time = time.time() - tstart

        logger.scalar_summary("episode per minute",
                              episodes_so_far / elapsed_time * 60,
                              episodes_so_far)
        logger.scalar_summary("step per second",
                              timesteps_so_far / elapsed_time, episodes_so_far)
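The cg helper used to compute the natural-gradient step direction is not part of this listing. A standard conjugate-gradient solver with the call signature used above (solving A x = g with A applied implicitly through fisher_vector_product) is sketched below; this is the textbook algorithm, not necessarily the repository's exact implementation.

import numpy as np


def cg(f_Ax, b, cg_iters=10, verbose=False, residual_tol=1e-10):
    """Solve A x = b where A is given implicitly by f_Ax (here, the Fisher-vector product)."""
    x = np.zeros_like(b)   # running solution, started at zero
    r = b.copy()           # residual b - A x
    p = b.copy()           # conjugate search direction
    rdotr = r.dot(r)
    for i in range(cg_iters):
        if verbose:
            print("cg iter {:2d} residual {:.6g}".format(i, rdotr))
        z = f_Ax(p)
        v = rdotr / p.dot(z)
        x += v * p
        r -= v * z
        newrdotr = r.dot(r)
        mu = newrdotr / rdotr
        p = r + mu * p
        rdotr = newrdotr
        if rdotr < residual_tol:
            break
    return x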
Example #7
def learn(
    env,
    policy_func,
    args,
    *,
    timesteps_per_batch,  # timesteps per actor per update
    clip_param,
    entcoeff,  # clipping parameter epsilon, entropy coeff
    optim_epochs,
    optim_stepsize,
    optim_batchsize,  # optimization hypers
    gamma,
    lam,  # advantage estimation
    adam_epsilon=1e-5,
    schedule='constant'
):  # 'schedule' anneals the stepsize parameters (epsilon and adam)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space,
                     ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg  # clipped surrogate
    pol_surr = -U.mean(tf.minimum(
        surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    policy_var_list = [
        v for v in var_list if v.name.split("/")[0].startswith("pi")
    ]
    saver = MpiSaver(policy_var_list, log_prefix=args.logdir)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    saver.restore(restore_from=args.restore_actor_from)
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     args,
                                     timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    # max_timesteps = 1e10
    cur_lrmult = 1.0

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    logger = Logger(args.logdir)

    while time.time() - tstart < 86400 * args.max_train_days:
        # if schedule == 'constant':
        #     cur_lrmult = 1.0
        # elif schedule == 'linear':
        #     cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        # else:
        #     raise NotImplementedError

        # logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before udpate
        atarg = (atarg - atarg.mean()
                 ) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=True)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        # logger.log("Optimizing...")
        # logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            # logger.log(fmt_row(13, np.mean(losses, axis=0)))

        saver.sync()
        # logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        # logger.log(fmt_row(13, meanlosses))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        # Logging
        logger.scalar_summary("episodes", len(lens), iters_so_far)

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.scalar_summary(lossname, lossval, episodes_so_far)

        logger.scalar_summary("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret),
                              episodes_so_far)

        logger.scalar_summary("step", np.mean(lenbuffer), episodes_so_far)
        logger.scalar_summary("reward", np.mean(rewbuffer), episodes_so_far)
        logger.scalar_summary("best reward", np.max(rewbuffer),
                              episodes_so_far)

        elapsed_time = time.time() - tstart

        logger.scalar_summary("episode per minute",
                              episodes_so_far / elapsed_time * 60,
                              episodes_so_far)
        logger.scalar_summary("step per second",
                              timesteps_so_far / elapsed_time, episodes_so_far)
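Both policy-gradient examples call add_vtarg_and_adv, which is not shown. The usual GAE(lambda) computation it performs looks roughly like the sketch below; the seg field names ("new", "rew", "vpred", "nextvpred") are assumed to follow the standard trajectory-segment generator and may differ in the original code.

import numpy as np


def add_vtarg_and_adv(seg, gamma, lam):
    """Fill seg["adv"] (GAE advantages) and seg["tdlamret"] (lambda returns)."""
    new = np.append(seg["new"], 0)  # episode-start flags; last entry is padding
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, "float32")
    rew = seg["rew"]
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = rew[t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["tdlamret"] = seg["adv"] + seg["vpred"]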