def train_single_thread(
        actor, critic, target_actor, target_critic,
        args, prepare_fn,
        global_episode, global_update_step, episodes_queue):
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)
    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    _, update_fn, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
    logger = Logger(args.logdir)

    buffer = create_buffer(args)
    if args.prioritized_replay:
        beta_decay_fn = create_decay_fn(
            "linear",
            initial_value=args.prioritized_replay_beta0,
            final_value=1.0,
            max_step=args.max_update_steps)

    actor_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.actor_lr,
        final_value=args.actor_lr_end,
        max_step=args.max_update_steps)
    critic_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.critic_lr,
        final_value=args.critic_lr_end,
        max_step=args.max_update_steps)

    update_step = 0
    received_examples = 1  # start at 1 to avoid division by zero in "updates per example"

    while global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads) \
            and global_update_step.value < args.max_update_steps * args.num_train_threads:
        actor_lr = actor_learning_rate_decay_fn(update_step)
        critic_lr = critic_learning_rate_decay_fn(update_step)
        actor_lr = min(args.actor_lr, max(args.actor_lr_end, actor_lr))
        critic_lr = min(args.critic_lr, max(args.critic_lr_end, critic_lr))

        # Drain all episodes currently waiting in the queue into the replay buffer.
        while True:
            try:
                replay = episodes_queue.get_nowait()
                for (observation, action, reward, next_observation, done) in replay:
                    buffer.add(observation, action, reward, next_observation, done)
                received_examples += len(replay)
            except py_queue.Empty:
                break

        if len(buffer) >= args.train_steps:
            if args.prioritized_replay:
                beta = beta_decay_fn(update_step)
                beta = min(1.0, max(args.prioritized_replay_beta0, beta))
                (tr_observations, tr_actions, tr_rewards,
                 tr_next_observations, tr_dones,
                 weights, batch_idxes) = buffer.sample(
                    batch_size=args.batch_size, beta=beta)
            else:
                (tr_observations, tr_actions, tr_rewards,
                 tr_next_observations, tr_dones) = buffer.sample(batch_size=args.batch_size)
                weights, batch_idxes = np.ones_like(tr_rewards), None

            step_metrics, step_info = update_fn(
                tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones,
                weights, actor_lr, critic_lr)

            update_step += 1
            global_update_step.value += 1

            if args.prioritized_replay:
                new_priorities = np.abs(step_info["td_error"]) + 1e-6
                buffer.update_priorities(batch_idxes, new_priorities)

            for key, value in step_metrics.items():
                value = to_numpy(value)[0]
                logger.scalar_summary(key, value, update_step)
            logger.scalar_summary("actor lr", actor_lr, update_step)
            logger.scalar_summary("critic lr", critic_lr, update_step)

            if update_step % args.save_step == 0:
                save_fn(update_step)
        else:
            time.sleep(1)

        logger.scalar_summary("buffer size", len(buffer), global_episode.value)
        logger.scalar_summary(
            "updates per example",
            update_step * args.batch_size / received_examples,
            global_episode.value)

    save_fn(update_step)
    raise KeyboardInterrupt
def play_single_thread(
        actor, critic, target_actor, target_critic,
        args, prepare_fn,
        global_episode, global_update_step, episodes_queue, best_reward):
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)
    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    act_fn, _, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
    logger = Logger(args.logdir)

    env = create_env(args)
    random_process = create_random_process(args)

    epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
    epsilon_decay_fn = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=epsilon_cycle_len,
        num_cycles=args.max_episodes // epsilon_cycle_len)

    episode = 1
    step = 0
    start_time = time.time()

    while global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads) \
            and global_update_step.value < args.max_update_steps * args.num_train_threads:
        if episode % 100 == 0:
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)
        epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))

        episode_metrics = {
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon
        }

        observation = env.reset(seed=seed, difficulty=args.difficulty)
        random_process.reset_states()
        done = False
        replay = []

        while not done:
            action = act_fn(observation, noise=epsilon * random_process.sample())
            next_observation, reward, done, _ = env.step(action)
            replay.append((observation, action, reward, next_observation, done))
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1
            observation = next_observation

        episodes_queue.put(replay)

        episode += 1
        global_episode.value += 1

        if episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)
        if episode_metrics["reward"] > 15.0 * args.reward_scale:
            save_fn(episode)

        step += episode_metrics["step"]
        elapsed_time = time.time() - start_time
        for key, value in episode_metrics.items():
            logger.scalar_summary(key, value, episode)
        logger.scalar_summary("episode per minute", episode / elapsed_time * 60, episode)
        logger.scalar_summary("step per second", step / elapsed_time, episode)

        if elapsed_time > 86400 * args.max_train_days:
            global_episode.value = args.max_episodes * (args.num_threads - args.num_train_threads) + 1

    raise KeyboardInterrupt
def train_multi_thread(actor, critic, target_actor, target_critic, args, prepare_fn, best_reward):
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)
    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    act_fn, update_fn, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
    logger = Logger(args.logdir)

    buffer = create_buffer(args)
    if args.prioritized_replay:
        beta_decay_fn = create_decay_fn(
            "linear",
            initial_value=args.prioritized_replay_beta0,
            final_value=1.0,
            max_step=args.max_episodes)

    env = create_env(args)
    random_process = create_random_process(args)

    actor_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.actor_lr,
        final_value=args.actor_lr_end,
        max_step=args.max_episodes)
    critic_learning_rate_decay_fn = create_decay_fn(
        "linear",
        initial_value=args.critic_lr,
        final_value=args.critic_lr_end,
        max_step=args.max_episodes)

    epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
    epsilon_decay_fn = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=epsilon_cycle_len,
        num_cycles=args.max_episodes // epsilon_cycle_len)

    episode = 0
    step = 0
    start_time = time.time()

    while episode < args.max_episodes:
        if episode % 100 == 0:
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)

        actor_lr = actor_learning_rate_decay_fn(episode)
        critic_lr = critic_learning_rate_decay_fn(episode)
        epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))

        episode_metrics = {
            "value_loss": 0.0,
            "policy_loss": 0.0,
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon
        }

        observation = env.reset(seed=seed, difficulty=args.difficulty)
        random_process.reset_states()
        done = False

        while not done:
            action = act_fn(observation, noise=epsilon * random_process.sample())
            next_observation, reward, done, _ = env.step(action)
            buffer.add(observation, action, reward, next_observation, done)
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1

            if len(buffer) >= args.train_steps:
                if args.prioritized_replay:
                    (tr_observations, tr_actions, tr_rewards,
                     tr_next_observations, tr_dones,
                     weights, batch_idxes) = buffer.sample(
                        batch_size=args.batch_size, beta=beta_decay_fn(episode))
                else:
                    (tr_observations, tr_actions, tr_rewards,
                     tr_next_observations, tr_dones) = buffer.sample(batch_size=args.batch_size)
                    weights, batch_idxes = np.ones_like(tr_rewards), None

                step_metrics, step_info = update_fn(
                    tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones,
                    weights, actor_lr, critic_lr)

                if args.prioritized_replay:
                    new_priorities = np.abs(step_info["td_error"]) + 1e-6
                    buffer.update_priorities(batch_idxes, new_priorities)

                for key, value in step_metrics.items():
                    value = to_numpy(value)[0]
                    episode_metrics[key] += value

            observation = next_observation

        episode += 1

        if episode_metrics["reward"] > 15.0 * args.reward_scale \
                and episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)
            save_fn(episode)

        step += episode_metrics["step"]
        elapsed_time = time.time() - start_time
        for key, value in episode_metrics.items():
            value = value if "loss" not in key else value / episode_metrics["step"]
            logger.scalar_summary(key, value, episode)
        logger.scalar_summary("episode per minute", episode / elapsed_time * 60, episode)
        logger.scalar_summary("step per second", step / elapsed_time, episode)
        logger.scalar_summary("actor lr", actor_lr, episode)
        logger.scalar_summary("critic lr", critic_lr, episode)

        if episode % args.save_step == 0:
            save_fn(episode)

        if elapsed_time > 86400 * args.max_train_days:
            episode = args.max_episodes + 1

    save_fn(episode)
    raise KeyboardInterrupt
class Trainer:
    """Train and validate a model on a single GPU."""

    def __init__(self, train_loader, val_loader, args):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.args = args
        self.model = get_model(args)
        self.epochs = args.epochs
        self.total_step = len(train_loader) * args.epochs
        self.step = 0
        self.epoch = 0
        self.start_epoch = 1
        self.lr = args.learning_rate
        self.best_acc = 0

        # Log
        self.log_path = (
            PROJECT_ROOT / Path(SAVE_DIR) /
            Path(datetime.now().strftime("%Y%m%d%H%M%S") + "-")
        ).as_posix()
        self.log_path = Path(self.get_dirname(self.log_path, args))
        if not Path.exists(self.log_path):
            Path(self.log_path).mkdir(parents=True, exist_ok=True)
        self.logger = Logger("train", self.log_path, args.verbose)
        self.logger.log("Checkpoint files will be saved in {}".format(self.log_path))

        self.logger.add_level('STEP', 21, 'green')
        self.logger.add_level('EPOCH', 22, 'cyan')
        self.logger.add_level('EVAL', 23, 'yellow')

        self.criterion = nn.CrossEntropyLoss()
        if self.args.cuda:
            self.criterion = self.criterion.cuda()
        if args.half:
            self.model.half()
            self.criterion.half()

        params = self.model.parameters()
        self.optimizer = get_optimizer(args.optimizer, params, args)

    def train(self):
        self.eval()
        for self.epoch in range(self.start_epoch, self.args.epochs + 1):
            self.adjust_learning_rate([int(self.args.epochs / 2), int(self.args.epochs * 3 / 4)], factor=0.1)
            self.train_epoch()
            self.eval()

        self.logger.writer.export_scalars_to_json(
            self.log_path.as_posix() + "/scalars-{}-{}-{}.json".format(
                self.args.model,
                self.args.seed,
                self.args.activation
            )
        )
        self.logger.writer.close()

    def train_epoch(self):
        self.model.train()
        eval_metrics = EvaluationMetrics(['Loss', 'Acc', 'Time'])

        for i, (images, labels) in enumerate(self.train_loader):
            st = time.time()
            self.step += 1

            images = torch.autograd.Variable(images)
            labels = torch.autograd.Variable(labels)
            if self.args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half:
                images = images.half()

            outputs, loss = self.compute_loss(images, labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            outputs = outputs.float()
            loss = loss.float()
            elapsed_time = time.time() - st

            _, preds = torch.max(outputs, 1)
            accuracy = (labels == preds.squeeze()).float().mean()

            batch_size = labels.size(0)
            eval_metrics.update('Loss', float(loss), batch_size)
            eval_metrics.update('Acc', float(accuracy), batch_size)
            eval_metrics.update('Time', elapsed_time, batch_size)

            if self.step % self.args.log_step == 0:
                self.logger.scalar_summary(eval_metrics.val, self.step, 'STEP')

                # Histogram of parameters and their gradients
                for tag, p in self.model.named_parameters():
                    tag = tag.split(".")
                    tag = "Train_{}".format(tag[0]) + "/" + "/".join(tag[1:])
                    try:
                        self.logger.writer.add_histogram(tag, p.clone().cpu().data.numpy(), self.step)
                        self.logger.writer.add_histogram(tag + '/grad', p.grad.clone().cpu().data.numpy(), self.step)
                    except Exception as e:
                        print("Check if variable {} is not used: {}".format(tag, e))

        self.logger.scalar_summary(eval_metrics.avg, self.step, 'EPOCH')

    def eval(self):
        self.model.eval()
        eval_metrics = EvaluationMetrics(['Loss', 'Acc', 'Time'])

        for i, (images, labels) in enumerate(self.val_loader):
            st = time.time()
            images = torch.autograd.Variable(images)
            labels = torch.autograd.Variable(labels)
            if self.args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half:
                images = images.half()

            outputs, loss = self.compute_loss(images, labels)

            outputs = outputs.float()
            loss = loss.float()
            elapsed_time = time.time() - st

            _, preds = torch.max(outputs, 1)
            accuracy = (labels == preds.squeeze()).float().mean()

            batch_size = labels.size(0)
            eval_metrics.update('Loss', float(loss), batch_size)
            eval_metrics.update('Acc', float(accuracy), batch_size)
            eval_metrics.update('Time', elapsed_time, batch_size)

        # Save best model
        if eval_metrics.avg['Acc'] > self.best_acc:
            self.save()
            self.logger.log("Saving best model: epoch={}".format(self.epoch))
            self.best_acc = eval_metrics.avg['Acc']
            self.maybe_delete_old_pth(log_path=self.log_path.as_posix(), max_to_keep=1)

        self.logger.scalar_summary(eval_metrics.avg, self.step, 'EVAL')

    def get_dirname(self, path, args):
        path += "{}-".format(getattr(args, 'dataset'))
        path += "{}-".format(getattr(args, 'seed'))
        path += "{}".format(getattr(args, 'model'))
        return path

    def save(self, filename=None):
        if filename is None:
            filename = os.path.join(self.log_path, 'model-{}.pth'.format(self.epoch))
        torch.save({
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': self.epoch,
            'best_acc': self.best_acc,
            'args': self.args
        }, filename)

    def load(self, filename=None):
        if filename is None:
            filename = self.log_path
        S = torch.load(filename) if self.args.cuda \
            else torch.load(filename, map_location=lambda storage, location: storage)
        self.model.load_state_dict(S['model'])
        self.optimizer.load_state_dict(S['optimizer'])
        self.epoch = S['epoch']
        self.best_acc = S['best_acc']
        self.args = S['args']

    def maybe_delete_old_pth(self, log_path, max_to_keep):
        """Model filename must end with xxx-xxx-[epoch].pth"""
        # (filename, epoch) pairs
        pths = [(f, int(f[:-4].split("-")[-1])) for f in os.listdir(log_path) if f.endswith('.pth')]
        if len(pths) > max_to_keep:
            sorted_pths = sorted(pths, key=lambda tup: tup[1])
            for i in range(len(pths) - max_to_keep):
                os.remove(os.path.join(log_path, sorted_pths[i][0]))

    def show_current_model(self):
        print("\n".join("{}: {}".format(k, v) for k, v in sorted(vars(self.args).items())))

        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        total_params = np.sum([np.prod(p.size()) for p in model_parameters])
        print('%s\n\n' % (type(self.model)))
        print('%s\n\n' % (inspect.getsource(self.model.__init__)))
        print('%s\n\n' % (inspect.getsource(self.model.forward)))

        # Table width: 95 characters
        print("*" * 40 + "%10s" % self.args.model + "*" * 45)
        print("*" * 40 + "PARAM INFO" + "*" * 45)
        print("-" * 95)
        print("| %40s | %25s | %20s |" % ("Param Name", "Shape", "Number of Params"))
        print("-" * 95)
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print("| %40s | %25s | %20d |" % (name, list(param.size()), np.prod(param.size())))
        print("-" * 95)
        print("Total Params: %d" % (total_params))
        print("*" * 95)

    def adjust_learning_rate(self, milestone, factor=0.1):
        if self.epoch in milestone:
            self.lr *= factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr

    def compute_loss(self, images, labels):
        outputs = self.model(images)
        loss = self.criterion(outputs, labels)
        return outputs, loss
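

# Assumed usage sketch (not part of the original sources): the data loaders and the
# argparse namespace are built elsewhere in the repo, so this helper only shows how
# the Trainer API above is meant to be driven end to end.
def run_training(train_loader, val_loader, args):
    trainer = Trainer(train_loader, val_loader, args)
    trainer.show_current_model()
    trainer.train()  # eval() runs before the first epoch and after every epoch
    return trainer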
class Defender(Trainer):
    """Perform various adversarial attacks and defenses on a pretrained model.

    Attack/defense schemes generate Tensors, not Variables.
    """

    def __init__(self, val_loader, args, **kwargs):
        self.val_loader = val_loader
        self.args = args
        self.model = get_model(args)
        self.step = 0
        self.cuda = self.args.cuda

        self.log_path = (
            PROJECT_ROOT / Path("experiments") /
            Path(datetime.now().strftime("%Y%m%d%H%M%S") + "-")
        ).as_posix()
        self.log_path = Path(self.get_dirname(self.log_path, args))
        if not Path.exists(self.log_path):
            Path(self.log_path).mkdir(parents=True, exist_ok=True)
        self.logger = Logger("defense", self.log_path, args.verbose)
        self.logger.log("Checkpoint files will be saved in {}".format(self.log_path))

        self.logger.add_level("ATTACK", 21, 'yellow')
        self.logger.add_level("DEFENSE", 22, 'cyan')
        self.logger.add_level("TEST", 23, 'white')
        self.logger.add_level("DIST", 11, 'white')

        self.kwargs = kwargs
        if args.domain_restrict:
            self.artifact = get_artifact(self.model, val_loader, args)
            self.kwargs['artifact'] = self.artifact

    def defend(self):
        self.model.eval()
        defense_scheme = getattr(defenses, self.args.defense)(self.model, self.args, **self.kwargs)

        source = self.model
        if self.args.source is not None and (self.args.ckpt_name != self.args.ckpt_src):
            target = self.args.ckpt_name
            self.args.model = self.args.source
            self.args.ckpt_name = self.args.ckpt_src
            source = get_model(self.args)
            self.logger.log("Transfer attack from {} -> {}".format(self.args.ckpt_src, target))
        attack_scheme = getattr(attacks, self.args.attack)(source, self.args, **self.kwargs)

        eval_metrics = EvaluationMetrics(['Test/Acc', 'Test/Top5', 'Test/Time'])
        eval_def_metrics = EvaluationMetrics(['Def-Test/Acc', 'Def-Test/Top5', 'Def-Test/Time'])
        attack_metrics = EvaluationMetrics(['Attack/Acc', 'Attack/Top5', 'Attack/Time'])
        defense_metrics = EvaluationMetrics(['Defense/Acc', 'Defense/Top5', 'Defense/Time'])
        dist_metrics = EvaluationMetrics(['L0', 'L1', 'L2', 'Li'])

        for i, (images, labels) in enumerate(self.val_loader):
            self.step += 1
            if self.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half:
                images = images.half()

            # Inference on clean images
            st = time.time()
            outputs = self.model(self.to_var(images, self.cuda, True))
            outputs = outputs.float()
            _, preds = torch.topk(outputs, 5)

            acc = (labels == preds.data[:, 0]).float().mean()
            top5 = torch.sum((labels.unsqueeze(1).repeat(1, 5) == preds.data).float(), dim=1).mean()
            eval_metrics.update('Test/Acc', float(acc), labels.size(0))
            eval_metrics.update('Test/Top5', float(top5), labels.size(0))
            eval_metrics.update('Test/Time', time.time() - st, labels.size(0))

            # Attacker
            st = time.time()
            adv_images, adv_labels = attack_scheme.generate(images, labels)
            if isinstance(adv_images, Variable):
                adv_images = adv_images.data
            attack_metrics.update('Attack/Time', time.time() - st, labels.size(0))

            # Lp distances between original and adversarial images
            diff = torch.abs(
                denormalize(adv_images, self.args.dataset) - denormalize(images, self.args.dataset))
            L0 = torch.sum((torch.sum(diff, dim=1) > 1e-3).float().view(labels.size(0), -1), dim=1).mean()
            diff = diff.view(labels.size(0), -1)
            L1 = torch.norm(diff, p=1, dim=1).mean()
            L2 = torch.norm(diff, p=2, dim=1).mean()
            Li = torch.max(diff, dim=1)[0].mean()
            dist_metrics.update('L0', float(L0), labels.size(0))
            dist_metrics.update('L1', float(L1), labels.size(0))
            dist_metrics.update('L2', float(L2), labels.size(0))
            dist_metrics.update('Li', float(Li), labels.size(0))

            # Defender
            st = time.time()
            def_images, def_labels = defense_scheme.generate(adv_images, adv_labels)
            if isinstance(def_images, Variable):  # FIXME - Variable in, Variable out for all methods
                def_images = def_images.data
            defense_metrics.update('Defense/Time', time.time() - st, labels.size(0))

            self.calc_stats('Attack', adv_images, images, adv_labels, labels, attack_metrics)
            self.calc_stats('Defense', def_images, images, def_labels, labels, defense_metrics)

            # Defense-Inference for shift of the original image
            st = time.time()
            def_images_org, _ = defense_scheme.generate(images, labels)
            if isinstance(def_images_org, Variable):  # FIXME - Variable in, Variable out for all methods
                def_images_org = def_images_org.data
            outputs = self.model(self.to_var(def_images_org, self.cuda, True))
            outputs = outputs.float()
            _, preds = torch.topk(outputs, 5)

            acc = (labels == preds.data[:, 0]).float().mean()
            top5 = torch.sum((labels.unsqueeze(1).repeat(1, 5) == preds.data).float(), dim=1).mean()
            eval_def_metrics.update('Def-Test/Acc', float(acc), labels.size(0))
            eval_def_metrics.update('Def-Test/Top5', float(top5), labels.size(0))
            eval_def_metrics.update('Def-Test/Time', time.time() - st, labels.size(0))

            if self.step % self.args.log_step == 0 or self.step == len(self.val_loader):
                self.logger.scalar_summary(eval_metrics.avg, self.step, 'TEST')
                self.logger.scalar_summary(eval_def_metrics.avg, self.step, 'TEST')
                self.logger.scalar_summary(attack_metrics.avg, self.step, 'ATTACK')
                self.logger.scalar_summary(defense_metrics.avg, self.step, 'DEFENSE')
                self.logger.scalar_summary(dist_metrics.avg, self.step, 'DIST')

                # Recovery rate: fraction of the accuracy lost to the attack that the defense restores
                defense_rate = eval_metrics.avg['Test/Acc'] - defense_metrics.avg['Defense/Acc']
                if eval_metrics.avg['Test/Acc'] - attack_metrics.avg['Attack/Acc']:
                    defense_rate /= eval_metrics.avg['Test/Acc'] - attack_metrics.avg['Attack/Acc']
                else:
                    defense_rate = 0
                defense_rate = 1 - defense_rate

                defense_top5 = eval_metrics.avg['Test/Top5'] - defense_metrics.avg['Defense/Top5']
                if eval_metrics.avg['Test/Top5'] - attack_metrics.avg['Attack/Top5']:
                    defense_top5 /= eval_metrics.avg['Test/Top5'] - attack_metrics.avg['Attack/Top5']
                else:
                    defense_top5 = 0
                defense_top5 = 1 - defense_top5

                self.logger.log(
                    "Defense Rate Top1: {:5.3f} | Defense Rate Top5: {:5.3f}".format(
                        defense_rate, defense_top5),
                    'DEFENSE')

            if self.step % self.args.img_log_step == 0:
                image_dict = {
                    'Original': to_np(denormalize(images, self.args.dataset))[0],
                    'Attacked': to_np(denormalize(adv_images, self.args.dataset))[0],
                    'Defensed': to_np(denormalize(def_images, self.args.dataset))[0],
                    'Perturbation': to_np(denormalize(images - adv_images, self.args.dataset))[0]
                }
                self.logger.image_summary(image_dict, self.step)

    def calc_stats(self, method, gen_images, images, gen_labels, labels, metrics):
        """gen_images: images generated by the attacker or the defender.

        Currently this only computes accuracy (and artifact) statistics.
        """
        success_rate = 0
        if not isinstance(gen_images, Variable):
            gen_images = self.to_var(gen_images.clone(), self.cuda, True)
        gen_outputs = self.model(gen_images)
        gen_outputs = gen_outputs.float()
        _, gen_preds = torch.topk(F.softmax(gen_outputs, dim=1), 5)
        if isinstance(gen_preds, Variable):
            gen_preds = gen_preds.data

        gen_acc = (labels == gen_preds[:, 0]).float().mean()
        gen_top5 = torch.sum((labels.unsqueeze(1).repeat(1, 5) == gen_preds).float(), dim=1).mean()
        metrics.update('{}/Acc'.format(method), float(gen_acc), labels.size(0))
        metrics.update('{}/Top5'.format(method), float(gen_top5), labels.size(0))

    def to_var(self, x, cuda, volatile=False):
        """Manual cuda handling so the same code path also supports CPU inference."""
        if cuda:
            x = x.cuda()
        return torch.autograd.Variable(x, volatile=volatile)
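

# Assumed usage sketch (not part of the original sources): `args.attack` and
# `args.defense` must name classes exposed by the repo's `attacks` and `defenses`
# modules; any extra scheme keyword arguments are forwarded to both schemes.
def run_defense_eval(val_loader, args, **scheme_kwargs):
    defender = Defender(val_loader, args, **scheme_kwargs)
    defender.defend()
    return defender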
def learn(
        env,
        policy_func,
        args,
        *,
        timesteps_per_batch,  # what to train on
        max_kl,
        cg_iters,
        gamma,
        lam,  # advantage estimation
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3):
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    entbonus = entcoeff * meanent

    vferr = U.mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = U.mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("pol")
    ]
    vf_var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("vf")
    ]
    vfadam = MpiAdam(vf_var_list)
    policy_var_list = [
        v for v in all_var_list if v.name.split("/")[0].startswith("pi")
    ]
    saver = MpiSaver(policy_var_list, log_prefix=args.logdir)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n(
        [U.sum(g * tangent) for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg],
                                     losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    saver.restore(restore_from=args.restore_actor_from)
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, args, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    logger = Logger(args.logdir)

    while time.time() - tstart < 86400 * args.max_train_days:
        # logger.log("********** Iteration %i ************" % iters_so_far)
        meanlosses = [0] * len(loss_names)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        segargs = seg["ob"], seg["ac"], seg["adv"]
        fvpargs = [arr[::5] for arr in segargs]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*segargs)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            pass  # logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: accept the step only if losses are finite,
            # the KL constraint holds, and the surrogate actually improved.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*segargs)))
                improve = surr - surrbefore
                # logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    pass  # logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    pass  # logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    pass  # logger.log("surrogate didn't improve. shrinking step.")
                else:
                    # logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                # logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)
        saver.sync()

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        # Logging
        logger.scalar_summary("episodes", len(lens), iters_so_far)
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.scalar_summary(lossname, lossval, episodes_so_far)
        logger.scalar_summary("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret),
                              episodes_so_far)
        logger.scalar_summary("step", np.mean(lenbuffer), episodes_so_far)
        logger.scalar_summary("reward", np.mean(rewbuffer), episodes_so_far)
        logger.scalar_summary("best reward", np.max(rewbuffer), episodes_so_far)
        elapsed_time = time.time() - tstart
        logger.scalar_summary("episode per minute",
                              episodes_so_far / elapsed_time * 60,
                              episodes_so_far)
        logger.scalar_summary("step per second",
                              timesteps_so_far / elapsed_time,
                              episodes_so_far)
def learn(
        env,
        policy_func,
        args,
        *,
        timesteps_per_batch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        adam_epsilon=1e-5,
        schedule='constant'):  # annealing for stepsize parameters (epsilon and adam)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_func("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(
        name='lrmult', dtype=tf.float32,
        shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    pol_surr = -U.mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = U.mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    policy_var_list = [
        v for v in var_list if v.name.split("/")[0].startswith("pi")
    ]
    saver = MpiSaver(policy_var_list, log_prefix=args.logdir)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    saver.restore(restore_from=args.restore_actor_from)
    adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, args, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    # max_timesteps = 1e10
    cur_lrmult = 1.0

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    logger = Logger(args.logdir)

    while time.time() - tstart < 86400 * args.max_train_days:
        # if schedule == 'constant':
        #     cur_lrmult = 1.0
        # elif schedule == 'linear':
        #     cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        # else:
        #     raise NotImplementedError

        # logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=True)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        # logger.log("Optimizing...")
        # logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"],
                                            batch["atarg"], batch["vtarg"],
                                            cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            # logger.log(fmt_row(13, np.mean(losses, axis=0)))
        saver.sync()

        # logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"],
                                       batch["atarg"], batch["vtarg"],
                                       cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        # logger.log(fmt_row(13, meanlosses))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        # Logging
        logger.scalar_summary("episodes", len(lens), iters_so_far)
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.scalar_summary(lossname, lossval, episodes_so_far)
        logger.scalar_summary("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret),
                              episodes_so_far)
        logger.scalar_summary("step", np.mean(lenbuffer), episodes_so_far)
        logger.scalar_summary("reward", np.mean(rewbuffer), episodes_so_far)
        logger.scalar_summary("best reward", np.max(rewbuffer), episodes_so_far)
        elapsed_time = time.time() - tstart
        logger.scalar_summary("episode per minute",
                              episodes_so_far / elapsed_time * 60,
                              episodes_so_far)
        logger.scalar_summary("step per second",
                              timesteps_so_far / elapsed_time,
                              episodes_so_far)