Ejemplo n.º 1
0
    def training_pipeline(self, **kwargs):
        """Build a pure-imitation training pipeline.

        Stage 1 runs with full (constant) teacher forcing for ``tf_steps``;
        stage 2 linearly anneals teacher forcing to zero over
        ``anneal_steps`` and then continues imitation without it for the
        remaining steps.
        """
        training_steps = int(3e8)
        tf_steps = int(5e6)
        anneal_steps = int(5e6)
        # Imitation steps that run entirely without teacher forcing.
        il_no_tf_steps = training_steps - tf_steps - anneal_steps
        assert il_no_tf_steps > 0

        on_gpu = torch.cuda.is_available()
        return TrainingPipeline(
            save_interval=5000000,
            metric_accumulate_interval=10000 if on_gpu else 1,
            optimizer_builder=Builder(optim.Adam, dict(lr=3e-4)),
            num_mini_batch=2 if on_gpu else 1,
            update_repeats=4,
            max_grad_norm=0.5,
            num_steps=30,
            named_losses={"imitation_loss": Imitation()},
            gamma=0.99,
            use_gae=True,
            gae_lambda=0.95,
            advance_scene_rollout_period=self.ADVANCE_SCENE_ROLLOUT_PERIOD,
            pipeline_stages=[
                # startp == endp == 1.0: teacher forcing held constant here.
                PipelineStage(
                    loss_names=["imitation_loss"],
                    max_stage_steps=tf_steps,
                    teacher_forcing=LinearDecay(
                        startp=1.0, endp=1.0, steps=tf_steps,
                    ),
                ),
                # Anneal teacher forcing to zero, then keep training on the
                # imitation loss without it.
                PipelineStage(
                    loss_names=["imitation_loss"],
                    max_stage_steps=anneal_steps + il_no_tf_steps,
                    teacher_forcing=LinearDecay(
                        startp=1.0, endp=0.0, steps=anneal_steps,
                    ),
                ),
            ],
            lr_scheduler_builder=Builder(
                LambdaLR, {"lr_lambda": LinearDecay(steps=training_steps)},
            ),
        )
Ejemplo n.º 2
0
    def training_pipeline(cls, **kwargs):
        """Build a two-stage pipeline: DAgger imitation, then PPO.

        Stage 1 trains the imitation loss with teacher forcing linearly
        annealed from 1.0 to 0.0 over ``dagger_steps``; stage 2 trains the
        PPO loss (with a decaying clip parameter) for ``ppo_steps``.
        """
        dagger_steps = int(1e4)  # fixed: local was misspelled `dagger_steos`
        ppo_steps = int(1e6)
        lr = 2.5e-4
        # More mini-batches when a GPU is available.
        num_mini_batch = 2 if not torch.cuda.is_available() else 6
        update_repeats = 4
        num_steps = 128
        metric_accumulate_interval = cls.MAX_STEPS * 10  # Log every 10 max length tasks
        save_interval = 10000
        gamma = 0.99
        use_gae = True
        gae_lambda = 1.0
        max_grad_norm = 0.5

        return TrainingPipeline(
            save_interval=save_interval,
            metric_accumulate_interval=metric_accumulate_interval,
            optimizer_builder=Builder(optim.Adam, dict(lr=lr)),
            num_mini_batch=num_mini_batch,
            update_repeats=update_repeats,
            max_grad_norm=max_grad_norm,
            num_steps=num_steps,
            named_losses={
                "ppo_loss": PPO(clip_decay=LinearDecay(ppo_steps),
                                **PPOConfig),
                "imitation_loss": Imitation(),  # We add an imitation loss.
            },
            gamma=gamma,
            use_gae=use_gae,
            gae_lambda=gae_lambda,
            advance_scene_rollout_period=cls.ADVANCE_SCENE_ROLLOUT_PERIOD,
            pipeline_stages=[
                PipelineStage(
                    loss_names=["imitation_loss"],
                    teacher_forcing=LinearDecay(
                        startp=1.0,
                        endp=0.0,
                        steps=dagger_steps,
                    ),
                    max_stage_steps=dagger_steps,
                ),
                PipelineStage(
                    loss_names=["ppo_loss"],
                    max_stage_steps=ppo_steps,
                ),
            ],
            lr_scheduler_builder=Builder(
                LambdaLR, {"lr_lambda": LinearDecay(steps=ppo_steps)}),
        )
    def _training_pipeline_info(cls) -> Dict[str, Any]:
        """Define how the model trains.

        Returns kwargs for a single-stage pipeline that combines a
        walkthrough-phase-masked PPO loss with imitation. Teacher forcing
        stays at 1.0 for the behavior-cloning steps and then decays linearly
        to 0.0 across the DAgger steps.
        """
        steps_total = cls.TRAINING_STEPS
        il_params = cls._use_label_to_get_training_params()
        bc_steps = il_params["bc_tf1_steps"]
        dagger_steps = il_params["dagger_steps"]

        losses = dict(
            walkthrough_ppo_loss=MaskedPPO(
                mask_uuid="in_walkthrough_phase",
                ppo_params=dict(clip_decay=LinearDecay(steps_total), **PPOConfig),
            ),
            imitation_loss=Imitation(),
        )
        stage = PipelineStage(
            loss_names=["walkthrough_ppo_loss", "imitation_loss"],
            max_stage_steps=steps_total,
            teacher_forcing=StepwiseLinearDecay(
                cumm_steps_and_values=[
                    (bc_steps, 1.0),
                    (bc_steps + dagger_steps, 0.0),
                ]
            ),
        )
        return dict(named_losses=losses, pipeline_stages=[stage], **il_params)
 def training_pipeline(cls, **kwargs):
     """Single-stage PPO training pipeline with a linearly decaying LR."""
     ppo_steps = int(2.5e8)
     return TrainingPipeline(
         save_interval=5000000,
         metric_accumulate_interval=1000,
         optimizer_builder=Builder(optim.Adam, dict(lr=3e-4)),
         num_mini_batch=1,
         update_repeats=3,
         max_grad_norm=0.5,
         num_steps=30,
         named_losses={"ppo_loss": PPO(**PPOConfig)},
         gamma=0.99,
         use_gae=True,
         gae_lambda=0.95,
         advance_scene_rollout_period=cls.ADVANCE_SCENE_ROLLOUT_PERIOD,
         pipeline_stages=[
             PipelineStage(loss_names=["ppo_loss"], max_stage_steps=ppo_steps)
         ],
         lr_scheduler_builder=Builder(
             LambdaLR, {"lr_lambda": LinearDecay(steps=ppo_steps)}
         ),
     )
Ejemplo n.º 5
0
 def training_pipeline(cls, **kwargs) -> TrainingPipeline:
     """One stage mixing imitation (teacher forcing annealed to zero over
     the first half of training) with conditional-entropy PPO."""
     ppo_steps = int(150000)
     imitation = Imitation(
         cls.SENSORS[1]
     )  # 0 is Minigrid, 1 is ExpertActionSensor
     ppo = PPO(**PPOConfig, entropy_method_name="conditional_entropy")
     return TrainingPipeline(
         named_losses=dict(imitation_loss=imitation, ppo_loss=ppo),  # type:ignore
         pipeline_stages=[
             PipelineStage(
                 teacher_forcing=LinearDecay(
                     startp=1.0, endp=0.0, steps=ppo_steps // 2,
                 ),
                 loss_names=["imitation_loss", "ppo_loss"],
                 max_stage_steps=ppo_steps,
             )
         ],
         optimizer_builder=Builder(cast(optim.Optimizer, optim.Adam), dict(lr=1e-4)),
         num_mini_batch=4,
         update_repeats=3,
         max_grad_norm=0.5,
         num_steps=16,
         gamma=0.99,
         use_gae=True,
         gae_lambda=0.95,
         advance_scene_rollout_period=None,
         save_interval=10000,
         metric_accumulate_interval=1,
         lr_scheduler_builder=Builder(
             LambdaLR, {"lr_lambda": LinearDecay(steps=ppo_steps)}  # type:ignore
         ),
     )
Ejemplo n.º 6
0
 def training_pipeline(cls, **kwargs) -> TrainingPipeline:
     """Single-stage PPO with very long rollouts (2000 steps), no GAE, and a
     linearly decaying learning rate."""
     ppo_steps = int(1.2e6)
     ppo_loss = PPO(
         clip_param=0.2,
         value_loss_coef=0.5,
         entropy_coef=0.0,
     )
     return TrainingPipeline(
         named_losses=dict(ppo_loss=ppo_loss),  # type:ignore
         pipeline_stages=[
             PipelineStage(loss_names=["ppo_loss"], max_stage_steps=ppo_steps),
         ],
         optimizer_builder=Builder(cast(optim.Optimizer, optim.Adam), dict(lr=1e-3)),
         num_mini_batch=1,
         update_repeats=80,
         max_grad_norm=100,
         num_steps=2000,
         gamma=0.99,
         use_gae=False,
         gae_lambda=0.95,
         advance_scene_rollout_period=None,
         save_interval=200000,
         metric_accumulate_interval=50000,
         lr_scheduler_builder=Builder(
             LambdaLR,
             {"lr_lambda": LinearDecay(steps=ppo_steps)},  # type:ignore
         ),
     )
 def training_pipeline(cls, **kwargs) -> TrainingPipeline:
     """Single-stage PPO pipeline with a linearly decaying learning rate."""
     ppo_steps = int(150000)
     ppo_loss = PPO(**PPOConfig)
     return TrainingPipeline(
         named_losses=dict(ppo_loss=ppo_loss),  # type:ignore
         pipeline_stages=[
             PipelineStage(loss_names=["ppo_loss"], max_stage_steps=ppo_steps)
         ],
         optimizer_builder=Builder(cast(optim.Optimizer, optim.Adam), dict(lr=1e-4)),
         num_mini_batch=4,
         update_repeats=3,
         max_grad_norm=0.5,
         num_steps=16,
         gamma=0.99,
         use_gae=True,
         gae_lambda=0.95,
         advance_scene_rollout_period=None,
         save_interval=10000,
         metric_accumulate_interval=1,
         lr_scheduler_builder=Builder(
             LambdaLR,
             {"lr_lambda": LinearDecay(steps=ppo_steps)}  # type:ignore
         ),
     )
    def _training_pipeline_info(cls, **kwargs) -> Dict[str, Any]:
        """Define how the model trains.

        Combines PPO with two auxiliary map-reconstruction losses in one
        stage, with the semantic-map loss weighted 100x the others.
        """
        steps = cls.TRAINING_STEPS
        named_losses = dict(
            ppo_loss=PPO(clip_decay=LinearDecay(steps), **PPOConfig),
            binned_map_loss=BinnedPointCloudMapLoss(
                binned_pc_uuid="binned_pc_map",
                map_logits_uuid="ego_height_binned_map_logits",
            ),
            semantic_map_loss=SemanticMapFocalLoss(
                semantic_map_uuid="semantic_map",
                map_logits_uuid="ego_semantic_map_logits",
            ),
        )
        return dict(
            named_losses=named_losses,
            pipeline_stages=[
                PipelineStage(
                    loss_names=["ppo_loss", "binned_map_loss", "semantic_map_loss"],
                    loss_weights=[1.0, 1.0, 100.0],
                    max_stage_steps=steps,
                )
            ],
            num_steps=32,
            num_mini_batch=1,
            update_repeats=3,
            use_lr_decay=True,
            lr=3e-4,
        )
Ejemplo n.º 9
0
    def training_pipeline(cls, **kwargs):
        """Imitation-only pipeline with constant (full) teacher forcing.

        The PPO defaults are queried only so that conservative values of
        ``num_mini_batch`` and ``update_repeats`` can be chosen.
        """
        total_train_steps = cls.TOTAL_IL_TRAIN_STEPS

        ppo_info = cls.rl_loss_default("ppo", steps=-1)
        imitation_info = cls.rl_loss_default("imitation")
        infos = [ppo_info, imitation_info]

        return cls._training_pipeline(
            named_losses={"imitation_loss": imitation_info["loss"]},
            pipeline_stages=[
                PipelineStage(
                    loss_names=["imitation_loss"],
                    # startp == endp == 1.0: teacher forcing never decays.
                    teacher_forcing=LinearDecay(
                        startp=1.0, endp=1.0, steps=total_train_steps,
                    ),
                    max_stage_steps=total_train_steps,
                ),
            ],
            num_mini_batch=min(info["num_mini_batch"] for info in infos),
            update_repeats=min(info["update_repeats"] for info in infos),
            total_train_steps=total_train_steps,
        )
    def training_pipeline(cls, **kwargs):
        """Single-stage pipeline combining PPO with grouped-action imitation.

        Expert supervision is grouped into {all non-END actions} vs {END},
        i.e. the imitation loss only teaches *when* to issue END, not which
        non-END action to take.
        """
        ppo_steps = int(75000000)

        action_strs = PointNavTask.class_action_names()
        non_end_action_inds_set = {
            i for i, a in enumerate(action_strs) if a != robothor_constants.END
        }
        end_action_ind_set = {action_strs.index(robothor_constants.END)}

        return TrainingPipeline(
            save_interval=5000000,
            metric_accumulate_interval=10000 if torch.cuda.is_available() else 1,
            optimizer_builder=Builder(optim.Adam, dict(lr=3e-4)),
            num_mini_batch=1,
            update_repeats=4,
            max_grad_norm=0.5,
            num_steps=128,
            named_losses={
                "ppo_loss": PPO(**PPOConfig),
                "grouped_action_imitation": GroupedActionImitation(
                    nactions=len(action_strs),
                    action_groups=[non_end_action_inds_set, end_action_ind_set],
                ),
            },
            gamma=0.99,
            use_gae=True,
            gae_lambda=0.95,
            advance_scene_rollout_period=cls.ADVANCE_SCENE_ROLLOUT_PERIOD,
            pipeline_stages=[
                PipelineStage(
                    loss_names=["ppo_loss", "grouped_action_imitation"],
                    max_stage_steps=ppo_steps,
                )
            ],
            lr_scheduler_builder=Builder(
                LambdaLR, {"lr_lambda": LinearDecay(steps=ppo_steps)},
            ),
        )
Ejemplo n.º 11
0
    def training_pipeline(cls, **kwargs):
        """Build a purely off-policy pipeline: expert cross-entropy (behavior
        cloning) on pre-recorded BabyAI trajectories; no on-policy losses."""
        total_train_steps = cls.TOTAL_IL_TRAIN_STEPS
        ppo_info = cls.rl_loss_default("ppo", steps=-1)

        # Borrow batching parameters from the default PPO configuration for
        # the off-policy updates below.
        num_mini_batch = ppo_info["num_mini_batch"]
        update_repeats = ppo_info["update_repeats"]

        # fmt: off
        return cls._training_pipeline(
            named_losses={
                "offpolicy_expert_ce_loss":
                MiniGridOffPolicyExpertCELoss(
                    total_episodes_in_epoch=int(1e6)),
            },
            pipeline_stages=[
                # Single stage, only with off-policy training
                PipelineStage(
                    loss_names=[],  # no on-policy losses
                    max_stage_steps=
                    total_train_steps,  # keep sampling episodes in the stage
                    # Enable off-policy training:
                    offpolicy_component=OffPolicyPipelineComponent(
                        # Pass a method to instantiate data iterators; the
                        # "-small" dataset is used when no GPU is available.
                        data_iterator_builder=lambda **extra_kwargs:
                        create_minigrid_offpolicy_data_iterator(
                            path=os.path.join(
                                BABYAI_EXPERT_TRAJECTORIES_DIR,
                                "BabyAI-GoToLocal-v0{}.pkl".
                                format("" if torch.cuda.is_available() else
                                       "-small"),
                            ),
                            nrollouts=cls.NUM_TRAIN_SAMPLERS //
                            num_mini_batch,  # per trainer batch size
                            rollout_len=cls.ROLLOUT_STEPS,
                            instr_len=cls.INSTR_LEN,
                            **extra_kwargs,
                        ),
                        loss_names=["offpolicy_expert_ce_loss"
                                    ],  # off-policy losses
                        updates=num_mini_batch *
                        update_repeats,  # number of batches per rollout
                    ),
                ),
            ],
            # As we don't have any on-policy losses, we set the next
            # two values to zero to ensure we don't attempt to
            # compute gradients for on-policy rollouts:
            num_mini_batch=0,
            update_repeats=0,
            total_train_steps=total_train_steps,
        )
Ejemplo n.º 12
0
    def training_pipeline(cls, **kwargs):
        """Single-stage A2C training pipeline using the class defaults."""
        total_training_steps = cls.TOTAL_RL_TRAIN_STEPS
        a2c_info = cls.rl_loss_default("a2c", steps=total_training_steps)
        stage = PipelineStage(
            loss_names=["a2c_loss"], max_stage_steps=total_training_steps,
        )
        return cls._training_pipeline(
            named_losses={"a2c_loss": a2c_info["loss"]},
            pipeline_stages=[stage],
            num_mini_batch=a2c_info["num_mini_batch"],
            update_repeats=a2c_info["update_repeats"],
            total_train_steps=total_training_steps,
        )
Ejemplo n.º 13
0
 def training_pipeline(cls, **kwargs):
     """Single-stage pipeline jointly optimizing PPO, an NIE regularizer
     over 3D keypoints/poses, and a yes/no imitation loss for END."""
     ppo_steps = int(10000000)
     lr = 3e-4
     num_mini_batch = 1
     update_repeats = 3
     num_steps = 30
     save_interval = 1000000
     log_interval = 100
     gamma = 0.99
     use_gae = True
     gae_lambda = 0.95
     max_grad_norm = 0.5
     return TrainingPipeline(
         save_interval=save_interval,
         metric_accumulate_interval=log_interval,
         optimizer_builder=Builder(optim.Adam, dict(lr=lr)),
         num_mini_batch=num_mini_batch,
         update_repeats=update_repeats,
         max_grad_norm=max_grad_norm,
         num_steps=num_steps,
         named_losses={
             "ppo_loss":
             PPO(**PPOConfig),
             # NOTE(review): uuids presumably name sensor/model outputs for
             # agent/object poses and keypoints — confirm against samplers.
             "nie_loss":
             NIE_Reg(
                 agent_pose_uuid="agent_pose_global",
                 pose_uuid="object_pose_global",
                 local_keypoints_uuid="3Dkeypoints_local",
                 global_keypoints_uuid="3Dkeypoints_global",
                 obj_update_mask_uuid="object_update_mask",
                 obj_action_mask_uuid="object_action_mask",
             ),
             # Imitation restricted to the binary decision of issuing END.
             "yn_im_loss":
             YesNoImitation(yes_action_index=ObjectPlacementTask.
                            class_action_names().index(END)),
         },
         gamma=gamma,
         use_gae=use_gae,
         gae_lambda=gae_lambda,
         advance_scene_rollout_period=cls.ADVANCE_SCENE_ROLLOUT_PERIOD,
         pipeline_stages=[
             PipelineStage(
                 loss_names=["ppo_loss", "nie_loss", "yn_im_loss"],
                 max_stage_steps=ppo_steps)
         ],
         lr_scheduler_builder=Builder(
             LambdaLR, {"lr_lambda": LinearDecay(steps=ppo_steps)}),
     )
    def _training_pipeline_info(cls, **kwargs) -> Dict[str, Any]:
        """Define how the model trains: one PPO stage with LR decay."""
        steps = cls.TRAINING_STEPS
        ppo_loss = PPO(clip_decay=LinearDecay(steps), **PPOConfig)
        return dict(
            named_losses=dict(ppo_loss=ppo_loss),
            pipeline_stages=[
                PipelineStage(loss_names=["ppo_loss"], max_stage_steps=steps,)
            ],
            num_steps=64,
            num_mini_batch=1,
            update_repeats=3,
            use_lr_decay=True,
            lr=3e-4,
        )
Ejemplo n.º 15
0
 def training_pipeline(cls, **kwargs):
     """Imitation pipeline: teacher forcing decays to zero over the first
     half of training, then imitation continues without it."""
     total_train_steps = cls.TOTAL_IL_TRAIN_STEPS
     loss_info = cls.rl_loss_default("imitation")
     stage = PipelineStage(
         loss_names=["imitation_loss"],
         teacher_forcing=LinearDecay(
             startp=1.0, endp=0.0, steps=total_train_steps // 2,
         ),
         max_stage_steps=total_train_steps,
     )
     return cls._training_pipeline(
         named_losses={"imitation_loss": loss_info["loss"]},
         pipeline_stages=[stage],
         num_mini_batch=loss_info["num_mini_batch"],
         update_repeats=loss_info["update_repeats"],
         total_train_steps=total_train_steps,
     )
Ejemplo n.º 16
0
 def training_pipeline(cls, **kwargs) -> TrainingPipeline:
     """Single-stage PPO with long rollouts and a constant learning rate."""
     ppo_steps = int(8e7)  # convergence may be after 1e8
     ppo_loss = PPO(
         clip_param=0.1,
         value_loss_coef=0.5,
         entropy_coef=0.0,
     )
     return TrainingPipeline(
         named_losses=dict(ppo_loss=ppo_loss),  # type:ignore
         pipeline_stages=[
             PipelineStage(loss_names=["ppo_loss"], max_stage_steps=ppo_steps),
         ],
         optimizer_builder=Builder(cast(optim.Optimizer, optim.Adam), dict(lr=1e-4)),
         num_mini_batch=4,  # optimal 64
         update_repeats=10,
         max_grad_norm=0.5,
         num_steps=2048,
         gamma=0.99,
         use_gae=True,
         gae_lambda=0.95,
         advance_scene_rollout_period=None,
         save_interval=200000,
         metric_accumulate_interval=50000,
         lr_scheduler_builder=Builder(
             LambdaLR,
             # startp == endp == 1 keeps the learning rate constant.
             {"lr_lambda": LinearDecay(steps=ppo_steps, startp=1, endp=1)},
         ),
     )
Ejemplo n.º 17
0
    def training_pipeline(self, **kwargs):
        """Single-stage PPO pipeline, optionally extended with auxiliary
        losses via ``self._update_with_auxiliary_losses``."""
        # PPO
        ppo_steps = int(75000000)
        lr = 3e-4
        num_mini_batch = 1
        update_repeats = 4
        num_steps = 128
        save_interval = 5000000
        log_interval = 10000 if torch.cuda.is_available() else 1
        gamma = 0.99
        use_gae = True
        gae_lambda = 0.95
        max_grad_norm = 0.5
        # NOTE(review): this mutates the shared module-level PPOConfig dict
        # in place; any later use of PPOConfig will see this value.
        PPOConfig["normalize_advantage"] = self.NORMALIZE_ADVANTAGE

        # Mapping: loss name -> (loss instance, loss weight).
        named_losses = {"ppo_loss": (PPO(**PPOConfig), 1.0)}
        named_losses = self._update_with_auxiliary_losses(named_losses)

        return TrainingPipeline(
            save_interval=save_interval,
            metric_accumulate_interval=log_interval,
            optimizer_builder=Builder(optim.Adam, dict(lr=lr)),
            num_mini_batch=num_mini_batch,
            update_repeats=update_repeats,
            max_grad_norm=max_grad_norm,
            num_steps=num_steps,
            # Strip the weights: TrainingPipeline wants only the losses here.
            named_losses={key: val[0]
                          for key, val in named_losses.items()},
            gamma=gamma,
            use_gae=use_gae,
            gae_lambda=gae_lambda,
            advance_scene_rollout_period=self.ADVANCE_SCENE_ROLLOUT_PERIOD,
            pipeline_stages=[
                PipelineStage(
                    loss_names=list(named_losses.keys()),
                    max_stage_steps=ppo_steps,
                    # Weights come from the second tuple element.
                    loss_weights=[val[1] for val in named_losses.values()],
                )
            ],
            lr_scheduler_builder=Builder(
                LambdaLR, {"lr_lambda": LinearDecay(steps=ppo_steps)}),
        )
    def training_pipeline(cls, **kwargs):
        """Off-policy-only pipeline: behavior cloning (expert cross-entropy)
        on stored BabyAI trajectories; no on-policy losses are optimized."""
        total_train_steps = cls.TOTAL_IL_TRAIN_STEPS
        ppo_info = cls.rl_loss_default("ppo", steps=-1)

        # Borrow batching parameters from the default PPO configuration for
        # the off-policy updates below.
        num_mini_batch = ppo_info["num_mini_batch"]
        update_repeats = ppo_info["update_repeats"]

        return cls._training_pipeline(
            named_losses={
                # Episodes per epoch are split across the training GPUs.
                "offpolicy_expert_ce_loss": MiniGridOffPolicyExpertCELoss(
                    total_episodes_in_epoch=int(1e6)
                    // len(cls.machine_params("train")["gpu_ids"])
                ),
            },
            pipeline_stages=[
                PipelineStage(
                    loss_names=[],  # no on-policy losses
                    max_stage_steps=total_train_steps,
                    offpolicy_component=OffPolicyPipelineComponent(
                        # Builds the expert-trajectory iterator; the "-small"
                        # dataset is used when no GPU is available.
                        data_iterator_builder=lambda **extra_kwargs: create_minigrid_offpolicy_data_iterator(
                            path=os.path.join(
                                BABYAI_EXPERT_TRAJECTORIES_DIR,
                                "BabyAI-GoToLocal-v0{}.pkl".format(
                                    "" if torch.cuda.is_available() else "-small"
                                ),
                            ),
                            nrollouts=cls.NUM_TRAIN_SAMPLERS // num_mini_batch,
                            rollout_len=cls.ROLLOUT_STEPS,
                            instr_len=cls.INSTR_LEN,
                            **extra_kwargs,
                        ),
                        data_iterator_kwargs_generator=cls.expert_ce_loss_kwargs_generator,
                        loss_names=["offpolicy_expert_ce_loss"],
                        updates=num_mini_batch * update_repeats,  # batches per rollout
                    ),
                ),
            ],
            # Zero on-policy mini-batches/updates: gradients come only from
            # the off-policy component above.
            num_mini_batch=0,
            update_repeats=0,
            total_train_steps=total_train_steps,
        )
Ejemplo n.º 19
0
    def _training_pipeline_info(cls, **kwargs) -> Dict[str, Any]:
        """Define how the model trains.

        Single imitation stage: full teacher forcing for the behavior-cloning
        steps, then a linear decay to zero across the DAgger steps.
        """
        steps = cls.TRAINING_STEPS
        params = cls._use_label_to_get_training_params()
        bc_steps = params["bc_tf1_steps"]
        dagger_steps = params["dagger_steps"]

        tf_schedule = StepwiseLinearDecay(
            cumm_steps_and_values=[
                (bc_steps, 1.0),
                (bc_steps + dagger_steps, 0.0),
            ]
        )
        return dict(
            named_losses=dict(imitation_loss=Imitation()),
            pipeline_stages=[
                PipelineStage(
                    loss_names=["imitation_loss"],
                    max_stage_steps=steps,
                    teacher_forcing=tf_schedule,
                )
            ],
            **params
        )
    def training_pipeline(self, **kwargs):
        """Distributed (multi-node) PPO pipeline.

        Starts with a small-batch stage whose per-rollout update count
        matches the 1-node/60-sampler baseline, then switches to the base
        configuration (larger effective batches from more samplers).
        """
        # These params are identical to the baseline configuration for 60 samplers (1 machine)
        ppo_steps = int(300e6)
        lr = 3e-4
        num_mini_batch = 1
        update_repeats = 4
        num_steps = 128
        save_interval = 5000000
        log_interval = 10000 if torch.cuda.is_available() else 1
        gamma = 0.99
        use_gae = True
        gae_lambda = 0.95
        max_grad_norm = 0.5

        # We add 30 million steps for small batch learning
        small_batch_steps = int(30e6)
        # And a short transition phase towards large learning rate
        # (see comment in the `lr_scheduler` helper method)
        transition_steps = int(2 / 3 * self.distributed_nodes * 1e6)

        # Find exact number of samplers per GPU
        assert (self.num_train_processes % len(self.train_gpu_ids) == 0
                ), "Expected uniform number of samplers per GPU"
        samplers_per_gpu = self.num_train_processes // len(self.train_gpu_ids)

        # Multiply num_mini_batch by the largest divisor of
        # samplers_per_gpu to keep all batches of same size:
        num_mini_batch_multiplier = [
            i for i in reversed(
                range(1,
                      min(samplers_per_gpu // 2, self.distributed_nodes) + 1))
            if samplers_per_gpu % i == 0
        ][0]

        # Multiply update_repeats so that the product of this factor and
        # num_mini_batch_multiplier is >= self.distributed_nodes:
        update_repeats_multiplier = int(
            math.ceil(self.distributed_nodes / num_mini_batch_multiplier))

        return TrainingPipeline(
            save_interval=save_interval,
            metric_accumulate_interval=log_interval,
            optimizer_builder=Builder(optim.Adam, dict(lr=lr)),
            num_mini_batch=num_mini_batch,
            update_repeats=update_repeats,
            max_grad_norm=max_grad_norm,
            num_steps=num_steps,
            named_losses={"ppo_loss": PPO(**PPOConfig, show_ratios=False)},
            gamma=gamma,
            use_gae=use_gae,
            gae_lambda=gae_lambda,
            advance_scene_rollout_period=self.ADVANCE_SCENE_ROLLOUT_PERIOD,
            pipeline_stages=[
                # We increase the number of batches for the first stage to reach an
                # equivalent number of updates per collected rollout data as in the
                # 1 node/60 samplers setting
                PipelineStage(
                    loss_names=["ppo_loss"],
                    max_stage_steps=small_batch_steps,
                    num_mini_batch=num_mini_batch * num_mini_batch_multiplier,
                    update_repeats=update_repeats * update_repeats_multiplier,
                ),
                # Then we proceed with the base configuration (leading to larger
                # batches due to the increased number of samplers)
                PipelineStage(
                    loss_names=["ppo_loss"],
                    max_stage_steps=ppo_steps - small_batch_steps,
                ),
            ],
            # We use the MultiLinearDecay curve defined by the helper function,
            # setting the learning rate scaling as the square root of the number
            # of nodes. Linear scaling might also work, but we leave that
            # check to the reader.
            lr_scheduler_builder=Builder(
                LambdaLR,
                {
                    "lr_lambda":
                    self.lr_scheduler(
                        small_batch_steps=small_batch_steps,
                        transition_steps=transition_steps,
                        ppo_steps=ppo_steps,
                        lr_scaling=math.sqrt(self.distributed_nodes),
                    )
                },
            ),
        )