def main(_):
    # Make sure we have a valid config that inherits all the keys defined in the
    # base config.
    validate_config(FLAGS.config, mode="pretrain")

    config = FLAGS.config
    exp_dir = osp.join(config.root_dir, FLAGS.experiment_name)
    setup_experiment(exp_dir, config, FLAGS.resume)

    # No need to do any pretraining if we're loading the raw pretrained
    # ImageNet baseline.
    if FLAGS.raw_imagenet:
        return

    # Setup compute device.
    if torch.cuda.is_available():
        device = torch.device(FLAGS.device)
    else:
        logging.info("No GPU device found. Falling back to CPU.")
        device = torch.device("cpu")
    logging.info("Using device: %s", device)

    # Set RNG seeds.
    if config.seed is not None:
        logging.info("Pretraining experiment seed: %d", config.seed)
        experiment.seed_rngs(config.seed)
        experiment.set_cudnn(config.cudnn_deterministic,
                             config.cudnn_benchmark)
    else:
        logging.info(
            "No RNG seed has been set for this pretraining experiment.")

    logger = Logger(osp.join(exp_dir, "tb"), FLAGS.resume)

    # Load factories.
    (
        model,
        optimizer,
        pretrain_loaders,
        downstream_loaders,
        trainer,
        eval_manager,
    ) = common.get_factories(config, device)

    # Create checkpoint manager.
    checkpoint_dir = osp.join(exp_dir, "checkpoints")
    checkpoint_manager = CheckpointManager(
        checkpoint_dir,
        model=model,
        optimizer=optimizer,
    )

    global_step = checkpoint_manager.restore_or_initialize()
    total_batches = max(1, len(pretrain_loaders["train"]))
    epoch = int(global_step / total_batches)
    complete = False
    stopwatch = Stopwatch()
    try:
        while not complete:
            for batch in pretrain_loaders["train"]:
                train_loss = trainer.train_one_iter(batch)

                if global_step % config.logging_frequency == 0:
                    for k, v in train_loss.items():
                        logger.log_scalar(v, global_step, k, "pretrain")
                    logger.flush()

                if global_step % config.eval.eval_frequency == 0:
                    # Evaluate the model on the pretraining validation dataset.
                    valid_loss = trainer.eval_num_iters(
                        pretrain_loaders["valid"],
                        config.eval.val_iters,
                    )
                    for k, v in valid_loss.items():
                        logger.log_scalar(v, global_step, k, "pretrain")

                    # Evaluate the model on the downstream datasets.
                    for split, downstream_loader in downstream_loaders.items():
                        eval_to_metric = eval_manager.evaluate(
                            model,
                            downstream_loader,
                            device,
                            config.eval.val_iters,
                        )
                        for eval_name, eval_out in eval_to_metric.items():
                            eval_out.log(
                                logger,
                                global_step,
                                eval_name,
                                f"downstream/{split}",
                            )

                # Save model checkpoint.
                if global_step % config.checkpointing_frequency == 0:
                    checkpoint_manager.save(global_step)

                # Exit if complete.
                global_step += 1
                if global_step > config.optim.train_max_iters:
                    complete = True
                    break

                time_per_iter = stopwatch.elapsed()
                logging.info(
                    "Iter[%d/%d] (Epoch %d), %.6fs/iter, Loss: %.3f",
                    global_step,
                    config.optim.train_max_iters,
                    epoch,
                    time_per_iter,
                    train_loss["train/total_loss"].item(),
                )
                stopwatch.reset()
            epoch += 1

    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting.")

    finally:
        checkpoint_manager.save(global_step)
        logger.close()
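
# Note: a minimal sketch of the Stopwatch helper assumed above, inferred only
# from its elapsed()/reset() call sites; the project's actual implementation
# may differ.
import time


class Stopwatch:
    """Measure wall-clock seconds between reset() calls."""

    def __init__(self):
        self.reset()

    def elapsed(self):
        """Return seconds elapsed since the last reset."""
        return time.time() - self._start

    def reset(self):
        """Restart the timer."""
        self._start = time.time()
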
Example #2
def main(_):
    # Make sure we have a valid config that inherits all the keys defined in the
    # base config.
    validate_config(FLAGS.config, mode="rl")

    config = FLAGS.config
    exp_dir = osp.join(
        config.save_dir,
        FLAGS.experiment_name,
        str(FLAGS.seed),
    )
    utils.setup_experiment(exp_dir, config, FLAGS.resume)

    # Setup compute device.
    if torch.cuda.is_available():
        device = torch.device(FLAGS.device)
    else:
        logging.info("No GPU device found. Falling back to CPU.")
        device = torch.device("cpu")
    logging.info("Using device: %s", device)

    # Set RNG seeds.
    if FLAGS.seed is not None:
        logging.info("RL experiment seed: %d", FLAGS.seed)
        experiment.seed_rngs(FLAGS.seed)
        experiment.set_cudnn(config.cudnn_deterministic,
                             config.cudnn_benchmark)
    else:
        logging.info("No RNG seed has been set for this RL experiment.")

    # Load env.
    env = utils.make_env(
        FLAGS.env_name,
        FLAGS.seed,
        action_repeat=config.action_repeat,
        frame_stack=config.frame_stack,
    )
    eval_env = utils.make_env(
        FLAGS.env_name,
        FLAGS.seed + 42,
        action_repeat=config.action_repeat,
        frame_stack=config.frame_stack,
        save_dir=osp.join(exp_dir, "video", "eval"),
    )

    # Dynamically set observation and action space values.
    config.sac.obs_dim = env.observation_space.shape[0]
    config.sac.action_dim = env.action_space.shape[0]
    config.sac.action_range = [
        float(env.action_space.low.min()),
        float(env.action_space.high.max()),
    ]

    # Re-save the config now that the dynamic values have been filled in, and
    # freeze it for safety :)
    utils.dump_config(exp_dir, config)
    config = config_dict.FrozenConfigDict(config)

    policy = agent.SAC(device, config.sac)

    buffer = utils.make_buffer(env, device, config)

    # Create checkpoint manager.
    checkpoint_dir = osp.join(exp_dir, "checkpoints")
    checkpoint_manager = CheckpointManager(
        checkpoint_dir,
        policy=policy,
        **policy.optim_dict(),
    )

    logger = Logger(osp.join(exp_dir, "tb"), FLAGS.resume)

    try:
        start = checkpoint_manager.restore_or_initialize()
        observation, done = env.reset(), False
        for i in tqdm(range(start, config.num_train_steps), initial=start):
            if i < config.num_seed_steps:
                action = env.action_space.sample()
            else:
                policy.eval()
                action = policy.act(observation, sample=True)
            next_observation, reward, done, info = env.step(action)

            # A time-limit truncation is not a true terminal state, so keep the
            # bootstrap mask at 1.0 unless the episode genuinely ended.
            if not done or "TimeLimit.truncated" in info:
                mask = 1.0
            else:
                mask = 0.0

            # When a pretrained reward wrapper is configured, also store the
            # rendered frame alongside the transition.
            if not config.reward_wrapper.pretrained_path:
                buffer.insert(observation, action, reward, next_observation,
                              mask)
            else:
                buffer.insert(
                    observation,
                    action,
                    reward,
                    next_observation,
                    mask,
                    env.render(mode="rgb_array"),
                )
            observation = next_observation

            if done:
                observation, done = env.reset(), False
                for k, v in info["episode"].items():
                    logger.log_scalar(v, info["total"]["timesteps"], k,
                                      "training")

            if i >= config.num_seed_steps:
                policy.train()
                train_info = policy.update(buffer, i)

                if (i + 1) % config.log_frequency == 0:
                    for k, v in train_info.items():
                        logger.log_scalar(v, info["total"]["timesteps"], k,
                                          "training")
                    logger.flush()

            if (i + 1) % config.eval_frequency == 0:
                eval_stats = evaluate(policy, eval_env,
                                      config.num_eval_episodes)
                for k, v in eval_stats.items():
                    logger.log_scalar(
                        v,
                        info["total"]["timesteps"],
                        f"average_{k}s",
                        "evaluation",
                    )
                logger.flush()

            if (i + 1) % config.checkpoint_frequency == 0:
                checkpoint_manager.save(i)

    except KeyboardInterrupt:
        print("Caught keyboard interrupt. Saving before quitting.")

    finally:
        checkpoint_manager.save(i)  # pylint: disable=undefined-loop-variable
        logger.close()
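
# Note: a hedged sketch of the evaluate() helper called above, inferred only
# from its call site (policy, eval_env, num_eval_episodes) and the fact that
# it returns a dict of scalar statistics; the metric names here are
# assumptions, not the project's actual code.
import numpy as np


def evaluate(policy, env, num_episodes):
    """Roll out the policy deterministically and average episode statistics."""
    returns, lengths = [], []
    policy.eval()
    for _ in range(num_episodes):
        observation, done = env.reset(), False
        episode_return, episode_length = 0.0, 0
        while not done:
            action = policy.act(observation, sample=False)
            observation, reward, done, _ = env.step(action)
            episode_return += reward
            episode_length += 1
        returns.append(episode_return)
        lengths.append(episode_length)
    return {
        "episode_return": float(np.mean(returns)),
        "episode_length": float(np.mean(lengths)),
    }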