def load_model_checkpoint(pretrained_path, device):
  """Load a pretrained model from its experiment directory."""
  config = load_config_from_dir(pretrained_path)
  model = common.get_model(config)
  model.to(device).eval()
  checkpoint_dir = os.path.join(pretrained_path, "checkpoints")
  checkpoint_manager = CheckpointManager(checkpoint_dir, model=model)
  global_step = checkpoint_manager.restore_or_initialize()
  logging.info("Restored model from checkpoint %d.", global_step)
  return config, model
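
# Example usage (a minimal sketch; the experiment path and the `frames` tensor
# below are hypothetical and only illustrate how the returned objects are used):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   config, model = load_model_checkpoint("/tmp/xirl/pretrain_run", device)
#   with torch.no_grad():
#     embeddings = model(frames.to(device))  # `frames`: a batch of video frames.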
def setup(): """Load the latest embedder checkpoint and dataloaders.""" config = utils.load_config_from_dir(FLAGS.experiment_path) model = common.get_model(config) downstream_loaders = common.get_downstream_dataloaders(config, False)["train"] checkpoint_dir = os.path.join(FLAGS.experiment_path, "checkpoints") if FLAGS.restore_checkpoint: checkpoint_manager = CheckpointManager(checkpoint_dir, model=model) global_step = checkpoint_manager.restore_or_initialize() logging.info("Restored model from checkpoint %d.", global_step) else: logging.info("Skipping checkpoint restore.") return model, downstream_loaders
def main(_):
  # Make sure we have a valid config that inherits all the keys defined in the
  # base config.
  validate_config(FLAGS.config, mode="pretrain")

  config = FLAGS.config
  exp_dir = osp.join(config.root_dir, FLAGS.experiment_name)
  setup_experiment(exp_dir, config, FLAGS.resume)

  # No need to do any pretraining if we're loading the raw pretrained
  # ImageNet baseline.
  if FLAGS.raw_imagenet:
    return

  # Setup compute device.
  if torch.cuda.is_available():
    device = torch.device(FLAGS.device)
  else:
    logging.info("No GPU device found. Falling back to CPU.")
    device = torch.device("cpu")
  logging.info("Using device: %s", device)

  # Set RNG seeds.
  if config.seed is not None:
    logging.info("Pretraining experiment seed: %d", config.seed)
    experiment.seed_rngs(config.seed)
    experiment.set_cudnn(config.cudnn_deterministic, config.cudnn_benchmark)
  else:
    logging.info("No RNG seed has been set for this pretraining experiment.")

  logger = Logger(osp.join(exp_dir, "tb"), FLAGS.resume)

  # Load factories.
  (
      model,
      optimizer,
      pretrain_loaders,
      downstream_loaders,
      trainer,
      eval_manager,
  ) = common.get_factories(config, device)

  # Create checkpoint manager.
  checkpoint_dir = osp.join(exp_dir, "checkpoints")
  checkpoint_manager = CheckpointManager(
      checkpoint_dir,
      model=model,
      optimizer=optimizer,
  )

  global_step = checkpoint_manager.restore_or_initialize()
  total_batches = max(1, len(pretrain_loaders["train"]))
  epoch = int(global_step / total_batches)
  complete = False
  stopwatch = Stopwatch()
  try:
    while not complete:
      for batch in pretrain_loaders["train"]:
        train_loss = trainer.train_one_iter(batch)

        if not global_step % config.logging_frequency:
          for k, v in train_loss.items():
            logger.log_scalar(v, global_step, k, "pretrain")
          logger.flush()

        if not global_step % config.eval.eval_frequency:
          # Evaluate the model on the pretraining validation dataset.
          valid_loss = trainer.eval_num_iters(
              pretrain_loaders["valid"],
              config.eval.val_iters,
          )
          for k, v in valid_loss.items():
            logger.log_scalar(v, global_step, k, "pretrain")

          # Evaluate the model on the downstream datasets.
          for split, downstream_loader in downstream_loaders.items():
            eval_to_metric = eval_manager.evaluate(
                model,
                downstream_loader,
                device,
                config.eval.val_iters,
            )
            for eval_name, eval_out in eval_to_metric.items():
              eval_out.log(
                  logger,
                  global_step,
                  eval_name,
                  f"downstream/{split}",
              )

        # Save model checkpoint.
        if not global_step % config.checkpointing_frequency:
          checkpoint_manager.save(global_step)

        # Exit if complete.
        global_step += 1
        if global_step > config.optim.train_max_iters:
          complete = True
          break

        time_per_iter = stopwatch.elapsed()
        logging.info(
            "Iter[{}/{}] (Epoch {}), {:.6f}s/iter, Loss: {:.3f}".format(
                global_step,
                config.optim.train_max_iters,
                epoch,
                time_per_iter,
                train_loss["train/total_loss"].item(),
            ))
        stopwatch.reset()
      epoch += 1

  except KeyboardInterrupt:
    logging.info("Caught keyboard interrupt. Saving model before quitting.")

  finally:
    checkpoint_manager.save(global_step)
    logger.close()
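
# For reference, a minimal sketch of a timer exposing the `elapsed()` / `reset()`
# interface used in the loop above (illustrative only; not necessarily the
# `Stopwatch` implementation this script imports):
#
#   import time
#
#   class SimpleStopwatch:
#     """Tracks wall-clock seconds since construction or the last reset."""
#
#     def __init__(self):
#       self.reset()
#
#     def elapsed(self):
#       return time.time() - self._start_time
#
#     def reset(self):
#       self._start_time = time.time()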
def main(_):
  # Make sure we have a valid config that inherits all the keys defined in the
  # base config.
  validate_config(FLAGS.config, mode="rl")

  config = FLAGS.config
  exp_dir = osp.join(
      config.save_dir,
      FLAGS.experiment_name,
      str(FLAGS.seed),
  )
  utils.setup_experiment(exp_dir, config, FLAGS.resume)

  # Setup compute device.
  if torch.cuda.is_available():
    device = torch.device(FLAGS.device)
  else:
    logging.info("No GPU device found. Falling back to CPU.")
    device = torch.device("cpu")
  logging.info("Using device: %s", device)

  # Set RNG seeds.
  if FLAGS.seed is not None:
    logging.info("RL experiment seed: %d", FLAGS.seed)
    experiment.seed_rngs(FLAGS.seed)
    experiment.set_cudnn(config.cudnn_deterministic, config.cudnn_benchmark)
  else:
    logging.info("No RNG seed has been set for this RL experiment.")

  # Load env.
  env = utils.make_env(
      FLAGS.env_name,
      FLAGS.seed,
      action_repeat=config.action_repeat,
      frame_stack=config.frame_stack,
  )
  eval_env = utils.make_env(
      FLAGS.env_name,
      FLAGS.seed + 42,
      action_repeat=config.action_repeat,
      frame_stack=config.frame_stack,
      save_dir=osp.join(exp_dir, "video", "eval"),
  )

  # Dynamically set observation and action space values.
  config.sac.obs_dim = env.observation_space.shape[0]
  config.sac.action_dim = env.action_space.shape[0]
  config.sac.action_range = [
      float(env.action_space.low.min()),
      float(env.action_space.high.max()),
  ]

  # Resave the config since the dynamic values have been updated at this point
  # and make it immutable for safety :)
  utils.dump_config(exp_dir, config)
  config = config_dict.FrozenConfigDict(config)

  policy = agent.SAC(device, config.sac)

  buffer = utils.make_buffer(env, device, config)

  # Create checkpoint manager.
  checkpoint_dir = osp.join(exp_dir, "checkpoints")
  checkpoint_manager = CheckpointManager(
      checkpoint_dir,
      policy=policy,
      **policy.optim_dict(),
  )

  logger = Logger(osp.join(exp_dir, "tb"), FLAGS.resume)

  try:
    start = checkpoint_manager.restore_or_initialize()
    observation, done = env.reset(), False
    for i in tqdm(range(start, config.num_train_steps), initial=start):
      if i < config.num_seed_steps:
        action = env.action_space.sample()
      else:
        policy.eval()
        action = policy.act(observation, sample=True)
      next_observation, reward, done, info = env.step(action)

      if not done or "TimeLimit.truncated" in info:
        mask = 1.0
      else:
        mask = 0.0

      if not config.reward_wrapper.pretrained_path:
        buffer.insert(observation, action, reward, next_observation, mask)
      else:
        buffer.insert(
            observation,
            action,
            reward,
            next_observation,
            mask,
            env.render(mode="rgb_array"),
        )
      observation = next_observation

      if done:
        observation, done = env.reset(), False
        for k, v in info["episode"].items():
          logger.log_scalar(v, info["total"]["timesteps"], k, "training")

      if i >= config.num_seed_steps:
        policy.train()
        train_info = policy.update(buffer, i)

        if (i + 1) % config.log_frequency == 0:
          for k, v in train_info.items():
            logger.log_scalar(v, info["total"]["timesteps"], k, "training")
          logger.flush()

      if (i + 1) % config.eval_frequency == 0:
        eval_stats = evaluate(policy, eval_env, config.num_eval_episodes)
        for k, v in eval_stats.items():
          logger.log_scalar(
              v,
              info["total"]["timesteps"],
              f"average_{k}s",
              "evaluation",
          )
        logger.flush()

      if (i + 1) % config.checkpoint_frequency == 0:
        checkpoint_manager.save(i)

  except KeyboardInterrupt:
    print("Caught keyboard interrupt. Saving before quitting.")

  finally:
    checkpoint_manager.save(i)  # pylint: disable=undefined-loop-variable
    logger.close()
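
# For reference, a sketch of what the `evaluate` helper called above could look
# like: roll the policy out deterministically for `num_episodes` episodes and
# average the per-episode stats reported by the env wrapper. The metric names
# and the `info["episode"]` convention are assumptions mirrored from the
# training loop, not necessarily the actual implementation.
#
#   import collections
#   import numpy as np
#
#   def evaluate(policy, env, num_episodes):
#     stats = collections.defaultdict(list)
#     for _ in range(num_episodes):
#       observation, done = env.reset(), False
#       while not done:
#         policy.eval()
#         action = policy.act(observation, sample=False)
#         observation, _, done, info = env.step(action)
#       for k, v in info["episode"].items():
#         stats[k].append(v)
#     return {k: np.mean(v) for k, v in stats.items()}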