def perturb_model(args, model, random_seed):
    """Build two copies of ``model`` perturbed in opposite directions.

    Each copy starts from the parent's state dict; sigma-scaled Gaussian
    noise is then added to the parameters of one copy and subtracted from
    the other (antithetic sampling). Seeding NumPy with ``random_seed``
    makes the perturbation reproducible from the seed alone.
    """
    positive = ES(args.small_net)
    negative = ES(args.small_net)
    positive.load_state_dict(model.state_dict())
    negative.load_state_dict(model.state_dict())
    np.random.seed(random_seed)
    pairs = zip(positive.es_params(), negative.es_params())
    for (_, pos_param), (_, neg_param) in pairs:
        noise = np.random.normal(0, 1, pos_param.size())
        pos_param += torch.from_numpy(args.sigma * noise).float()
        neg_param += torch.from_numpy(args.sigma * -noise).float()
    return [positive, negative]
        help='Silence print statements during training')
# NOTE(review): the opening of the add_argument call completed above lies
# outside this chunk's view (presumably a '--silent' flag — confirm).
parser.add_argument('--test',
                    action='store_true',
                    help='Just render the env, no training')

if __name__ == '__main__':
    args = parser.parse_args()
    # Perturbations are evaluated in pairs (model, anti-model), so the
    # population size must be even.
    assert args.n % 2 == 0
    # The simple MLP only supports the classic-control envs below; fall
    # back to CartPole-v1 for anything else.
    if args.small_net and args.env_name not in [
            'CartPole-v0', 'CartPole-v1', 'MountainCar-v0']:
        args.env_name = 'CartPole-v1'
        print('Switching env to CartPole')
    env = create_atari_env(args.env_name)
    # Checkpoints are grouped per environment name.
    chkpt_dir = 'checkpoints/%s/' % args.env_name
    if not os.path.exists(chkpt_dir):
        os.makedirs(chkpt_dir)
    synced_model = ES(env.observation_space.shape[0],
                      env.action_space, args.small_net)
    # ES never backpropagates, so autograd bookkeeping is disabled.
    for param in synced_model.parameters():
        param.requires_grad = False
    if args.restore:
        # Resume from a previously saved state dict.
        state_dict = torch.load(args.restore)
        synced_model.load_state_dict(state_dict)
    if args.test:
        render_env(args, synced_model, env)
    else:
        train_loop(args, synced_model, env, chkpt_dir)
        action='store_true',
        help='Use simple MLP on CartPole')
# NOTE(review): the opening of the add_argument call completed above lies
# outside this chunk's view (presumably a '--small-net' flag — confirm).
parser.add_argument('--variable-ep-len',
                    action='store_true',
                    help="Change max episode length during training")
parser.add_argument('--silent',
                    action='store_true',
                    help='Silence print statements during training')
parser.add_argument('--test',
                    action='store_true',
                    help='Just render the env, no training')

if __name__ == '__main__':
    args = parser.parse_args()
    # Perturbations are evaluated in pairs (model, anti-model), so the
    # population size must be even.
    assert args.n % 2 == 0
    # Checkpoints are grouped per environment name.
    chkpt_dir = 'checkpoints/%s/' % args.env_name
    if not os.path.exists(chkpt_dir):
        os.makedirs(chkpt_dir)
    synced_model = ES(args.small_net)
    # ES never backpropagates, so autograd bookkeeping is disabled.
    for param in synced_model.parameters():
        param.requires_grad = False
    if args.restore:
        # Resume from a previously saved state dict.
        state_dict = torch.load(args.restore)
        synced_model.load_state_dict(state_dict)
    if args.test:
        render_env(args, synced_model)
    else:
        train_loop(args, synced_model, chkpt_dir)
        help='Silence print statements during training')
# NOTE(review): the opening of the add_argument call completed above lies
# outside this chunk's view (presumably a '--silent' flag — confirm).
parser.add_argument('--test',
                    action='store_true',
                    help='Just render the env, no training')
parser.add_argument('--max-gradient-updates',
                    type=int,
                    default=100000,
                    metavar='MGU',
                    help='maximum number of updates')

if __name__ == '__main__':
    args = parser.parse_args()
    # Perturbations are evaluated in pairs (model, anti-model), so the
    # population size must be even.
    assert args.n % 2 == 0
    # Single fixed env (TicTacToe), so one shared checkpoint directory.
    chkpt_dir = 'checkpoints/'
    if not os.path.exists(chkpt_dir):
        os.makedirs(chkpt_dir)
    env = TicTacToeEnv()
    synced_model = ES(env.observation_space, env.action_space)
    # ES never backpropagates, so autograd bookkeeping is disabled.
    for param in synced_model.parameters():
        param.requires_grad = False
    if args.restore:
        # Resume from a previously saved state dict.
        state_dict = torch.load(args.restore)
        synced_model.load_state_dict(state_dict)
    if args.test:
        render_env(synced_model)
    else:
        train_loop(args, synced_model, chkpt_dir)
def perturb_model(args, model, random_seed, env):
    """Return a pair of oppositely perturbed copies of ``model``.

    Both copies are constructed with the same architecture flags and load
    the parent's state dict; a single flat Gaussian noise vector scaled by
    ``args.sigma`` is then added to one copy and subtracted from the other
    (antithetic sampling). Seeding NumPy with ``random_seed`` means a
    worker elsewhere can regenerate the identical perturbation from the
    seed alone.

    Args:
        args: Namespace providing ``sigma``, ``a3c_net`` and
            ``virtual_batch_norm``.
        model: Parent ES model whose weights are copied.
        random_seed: Seed making the noise vector reproducible.
        env: Environment supplying observation/action spaces for
            constructing the copies.

    Returns:
        list: ``[perturbed_model, anti_perturbed_model]``.
    """
    new_model = ES(env.observation_space, env.action_space,
                   use_a3c_net=args.a3c_net,
                   use_virtual_batch_norm=args.virtual_batch_norm)
    anti_model = ES(env.observation_space, env.action_space,
                    use_a3c_net=args.a3c_net,
                    use_virtual_batch_norm=args.virtual_batch_norm)
    new_model.load_state_dict(model.state_dict())
    anti_model.load_state_dict(model.state_dict())
    np.random.seed(random_seed)
    # One flat noise vector covering every trainable parameter; the model
    # is responsible for scattering it across its tensors.
    eps = args.sigma * np.random.normal(0.0, 1.0,
                                        size=model.count_parameters())
    new_model.adjust_es_params(add=eps)
    anti_model.adjust_es_params(add=-eps)
    return [new_model, anti_model]
env = create_atari_env(args.env_name, frame_stack_size=args.stack_images, noop_init=args.noop_init, image_dim=args.image_dim) # set checkpoint directory if args.checkpoint_dir: chkpt_dir = args.checkpoint_dir else: chkpt_dir = 'checkpoints/%s/' % args.env_name if not os.path.exists(chkpt_dir): os.makedirs(chkpt_dir) # instantiate model (and restore if needed) synced_model = ES(env.observation_space, env.action_space, use_a3c_net=args.a3c_net, use_virtual_batch_norm=args.virtual_batch_norm) for param in synced_model.parameters(): param.requires_grad = False if args.restore: state_dict = torch.load(args.restore) synced_model.load_state_dict(state_dict) # compute batch for virtual batch normalization if args.virtual_batch_norm and not args.test: # print('Computing batch for virtual batch normalization') virtual_batch = gather_for_virtual_batch_norm( env, batch_size=args.virtual_batch_norm) virtual_batch = torchify(virtual_batch, unsqueeze=False) else: virtual_batch = None