Example #1
    def fit(
        self,
        train_env: Environment,
        valid_env: Environment,
    ):
        episodes = 0
        with tqdm.tqdm(desc="training") as train_pbar:

            while not train_env.is_closed():
                for i, batch in enumerate(train_env):
                    if isinstance(batch, Observations):
                        observations, rewards = batch, None
                    else:
                        observations, rewards = batch

                    batch_size = observations.x.shape[0]

                    y_pred = train_env.action_space.sample()

                    # If we're at the last batch, it might have a different size, so we
                    # give only the required number of values.
                    if isinstance(y_pred, (np.ndarray, Tensor)):
                        if y_pred.shape[0] != batch_size:
                            y_pred = y_pred[:batch_size]

                    if rewards is None:
                        rewards = train_env.send(y_pred)

                    train_pbar.set_postfix({"Episode": episodes, "Step": i})
                    # train as you usually would.

                episodes += 1
                if self.max_train_episodes and episodes >= self.max_train_episodes:
                    train_env.close()
                    break
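
The `train_env.send(y_pred)` call above follows an "active dataloader" protocol: when iterating yields observations without rewards, the environment withholds the labels until the method sends its predictions back. A minimal, self-contained sketch of that idea (not Sequoia's actual Environment class, just an illustration of the protocol) could look like this:

from typing import Optional


class ToyActiveEnv:
    """Hypothetical stand-in for an 'active' dataloader/environment: iterating
    yields observations only, and the reward (here, whether the prediction was
    correct) is returned by `send()` once an action has been provided."""

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
        self._index: Optional[int] = None

    def __iter__(self):
        for i, x in enumerate(self.data):
            self._index = i
            yield x  # Observation only; the label is withheld.

    def send(self, action):
        # The "reward" is only computed once the action for the current
        # observation has been received.
        return int(action == self.labels[self._index])


env = ToyActiveEnv(data=[0.1, 0.9, 0.4], labels=[0, 1, 0])
total_reward = 0
for obs in env:
    y_pred = int(obs > 0.5)           # Some arbitrary "policy".
    total_reward += env.send(y_pred)  # The reward only becomes available now.
print(total_reward)  # 3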
Example #2
    def shared_step(
        self,
        batch: Tuple[Observations, Rewards],
        batch_idx: int,
        environment: Environment,
        loss_name: str,
        dataloader_idx: int = None,
        optimizer_idx: int = None,
    ) -> Dict:
        """
        This is the shared step for this 'example' LightningModule.
        Feel free to customize/change it if you want!
        """
        if dataloader_idx is not None:
            assert isinstance(dataloader_idx, int)
            loss_name += f"/{dataloader_idx}"

        # Split the batch into observations and rewards.
        # NOTE: Only in the case of the Supervised settings do we ever get the
        # Rewards at the same time as the Observations.
        # TODO: It would be nice if we could actually do the same things for
        # both sides of the tree here..
        observations, rewards = self.split_batch(batch)

        # FIXME: Remove this, debugging:
        assert isinstance(observations, Observations), observations
        assert isinstance(observations.x, Tensor), observations.shapes
        # Get the forward pass results, containing:
        # - "observation": the augmented/transformed/processed observation.
        # - "representations": the representations for the observations.
        # - "actions": The actions (predictions)
        forward_pass: ForwardPass = self(observations)

        # get the actions from the forward pass:
        actions = forward_pass.actions

        if rewards is None:
            # Get the reward from the environment (the dataloader).
            if self.config.debug and self.config.render:
                environment.render("human")
                # import matplotlib.pyplot as plt
                # plt.waitforbuttonpress(10)

            rewards = environment.send(actions)
            assert rewards is not None

        loss: Loss = self.get_loss(forward_pass, rewards, loss_name=loss_name)
        return {
            "loss": loss.loss,
            "loss_object": loss,
        }
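
Since this shared step lives in a LightningModule, its return value can be passed straight back from `training_step`: PyTorch Lightning backpropagates through the tensor stored under the "loss" key of a returned dict. A hypothetical wrapper, assuming the module keeps a reference to its training environment in `self.train_env` (not shown in the snippet above), might look like:

    def training_step(self, batch: Tuple[Observations, Rewards], batch_idx: int) -> Dict:
        # Hypothetical delegation: Lightning optimizes the "loss" entry of the
        # dict that shared_step returns.
        return self.shared_step(
            batch,
            batch_idx,
            environment=self.train_env,  # assumed attribute, set elsewhere
            loss_name="train",
        )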
Example #3
    def fit(
        self, train_env: Environment, valid_env: Environment,
    ):
        # Add wrappers, if necessary.
        for wrapper in self.additional_train_wrappers:
            train_env = wrapper(train_env)
        for wrapper in self.additional_valid_wrappers:
            valid_env = wrapper(valid_env)

        train_env = CheckAttributesWrapper(
            train_env, attributes=self.changing_attributes
        )
        valid_env = CheckAttributesWrapper(
            valid_env, attributes=self.changing_attributes
        )

        self.train_env = train_env
        self.valid_env = valid_env
        # TODO: Fix any issues with how the RandomBaselineMethod deals with
        # RL envs
        # return super().fit(train_env, valid_env)
        episodes = 0
        val_interval = 10
        self.train_steps_per_task.append(0)
        self.train_episodes_per_task.append(0)

        while not train_env.is_closed() and (
            episodes < self.max_train_episodes if self.max_train_episodes else True
        ):

            obs = train_env.reset()
            task_labels = obs.task_labels
            if (
                task_labels is None
                or isinstance(task_labels, int)
                or not task_labels.shape
            ):
                task_labels = [task_labels]
            self.observation_task_labels.extend(task_labels)

            done = False
            while not done and not train_env.is_closed():
                actions = train_env.action_space.sample()
                # print(train_env.current_task)
                obs, rew, done, info = train_env.step(actions)
                self.train_steps_per_task[-1] += 1

            episodes += 1
            self.train_episodes_per_task[-1] += 1

            if episodes % val_interval == 0 and not valid_env.is_closed():
                obs = valid_env.reset()
                done = False
                while not done and not valid_env.is_closed():
                    actions = valid_env.action_space.sample()
                    obs, rew, done, info = valid_env.step(actions)

        self.all_train_values.append(self.train_env.values)
        self.all_valid_values.append(self.valid_env.values)
        self.n_fit_calls += 1
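
The `CheckAttributesWrapper` used above is what fills in the `.values` read back at the end of `fit`. As a rough sketch of how such a wrapper might work (an assumption about its interface, not the actual implementation), a gym `Wrapper` can simply record the chosen attributes at every step:

import gym


class AttributeRecorderWrapper(gym.Wrapper):
    """Hypothetical sketch: record the value of selected env attributes at each
    step and expose them as `values` (step index -> {attribute: value})."""

    def __init__(self, env: gym.Env, attributes):
        super().__init__(env)
        self.attributes = list(attributes)
        self.values = {}
        self._steps = 0

    def step(self, action):
        result = self.env.step(action)
        self.values[self._steps] = {
            name: getattr(self.env.unwrapped, name, None) for name in self.attributes
        }
        self._steps += 1
        return result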
Example #4
    def fit(
        self, train_env: Environment, valid_env: Environment,
    ):
        # This method doesn't actually train anything: it just steps through the
        # training environment (sampling random actions) and then returns.
        if isinstance(train_env.unwrapped, PassiveEnvironment):
            # Do one 'epoch' only:
            for batch in train_env:
                action = train_env.action_space.sample()
                rewards = train_env.send(action)
        else:
            while not train_env.is_closed():
                obs = train_env.reset()
                done = False
                while not done:
                    action = train_env.action_space.sample()
                    obs, rewards, done, info = train_env.step(action)
        return
Example #5
    def fit(self, train_env: Environment, valid_env: Environment):
        self.net.train()
        # Simple example training loop, with a basic validation loop and early stopping.
        best_val_loss = np.inf
        best_epoch = 0
        for epoch in range(self.epochs_per_task):
            train_pbar = tqdm.tqdm(train_env, desc=f"Training Epoch {epoch}")
            postfix = {}

            obs: ClassIncrementalSetting.Observations
            rew: ClassIncrementalSetting.Rewards
            for i, (obs, rew) in enumerate(train_pbar):
                self.optim.zero_grad()

                obs = obs.to(device=self.device)
                x = obs.x
                logits = self.net(x)

                if rew is None:
                    # If our online training performance is being measured, we might
                    # need to provide actions before we can get the corresponding
                    # rewards (image labels in this case).
                    y_pred = logits.argmax(1)
                    rew = train_env.send(y_pred)

                rew = rew.to(device=self.device)
                y = rew.y
                loss = F.cross_entropy(logits, y)

                postfix["loss"] = loss.detach().item()
                if self.task > 0 and self.buffer:
                    b_samples = self.buffer.sample(x.size(0))
                    b_logits = self.net(b_samples["x"])
                    loss_replay = F.cross_entropy(b_logits, b_samples["y"])
                    loss += loss_replay
                    postfix["replay loss"] = loss_replay.detach().item()

                loss.backward()
                self.optim.step()

                train_pbar.set_postfix(postfix)

                # Only add new samples to the buffer (only during first epoch).
                if self.buffer and epoch == 0:
                    self.buffer.add_reservoir({"x": x, "y": y, "t": self.task})

            # Validation loop:
            self.net.eval()
            torch.set_grad_enabled(False)
            val_pbar = tqdm.tqdm(valid_env)
            val_pbar.set_description(f"Validation Epoch {epoch}")
            epoch_val_loss = 0.0

            for i, (obs, rew) in enumerate(val_pbar):
                obs = obs.to(device=self.device)
                x = obs.x
                logits = self.net(x)

                if rew is None:
                    y_pred = logits.argmax(-1)
                    rew = valid_env.send(y_pred)

                assert rew is not None
                rew = rew.to(device=self.device)
                y = rew.y
                val_loss = F.cross_entropy(logits, y).item()

                epoch_val_loss += val_loss
                postfix["validation loss"] = epoch_val_loss
                val_pbar.set_postfix(postfix)
            torch.set_grad_enabled(True)
            # Put the network back in training mode for the next training epoch.
            self.net.train()

            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                best_epoch = epoch
            if epoch - best_epoch > self.early_stop_patience:
                print(f"Early stopping at epoch {epoch}.")
                # TODO: Reload the weights from the best epoch.
                break
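
The `self.buffer` object used above (here and in the next example) is only exercised through `sample(n)` and `add_reservoir(batch_dict)`. A minimal sketch of a replay buffer with that interface, assuming reservoir sampling is the intended behaviour (the real buffer implementation is not shown in these examples), could be:

import torch


class ReservoirBuffer:
    """Hypothetical minimal replay buffer exposing the interface used above:
    `add_reservoir(batch_dict)` and `sample(n)`. Reservoir sampling keeps a
    uniform sample of everything that has been added so far."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.items = {}   # e.g. {"x": [tensor, ...], "y": [tensor, ...], "t": [tensor, ...]}
        self.n_seen = 0

    def __bool__(self) -> bool:
        # Lets callers write `if self.buffer:` as in the examples above.
        return self.n_seen > 0

    def add_reservoir(self, batch: dict) -> None:
        # Batch size comes from "x"; scalar entries (like a task id) are
        # broadcast to every item in the batch.
        batch_size = len(batch["x"])
        for i in range(batch_size):
            item = {
                key: torch.as_tensor(value[i] if hasattr(value, "__len__") else value)
                for key, value in batch.items()
            }
            if self.n_seen < self.capacity:
                for key, value in item.items():
                    self.items.setdefault(key, []).append(value)
            else:
                j = int(torch.randint(0, self.n_seen + 1, ()).item())
                if j < self.capacity:
                    for key, value in item.items():
                        self.items[key][j] = value
            self.n_seen += 1

    def sample(self, n: int) -> dict:
        stored = len(self.items["x"])
        idx = torch.randint(0, stored, (min(n, stored),))
        return {key: torch.stack(values)[idx] for key, values in self.items.items()}

With this interface, `self.buffer.sample(x.size(0))` returns a dict of stacked tensors, matching how `b_samples["x"]` and `b_samples["y"]` are consumed by the replay loss above.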
Example #6
    def fit(self, train_env: Environment, valid_env: Environment):
        self.net.train()
        # Simple example training loop, with a basic validation loop and early stopping.
        best_val_loss = np.inf
        best_epoch = 0

        for epoch in range(self.epochs_per_task):
            train_pbar = tqdm.tqdm(train_env, desc=f"Training Epoch {epoch}")
            postfix = {}

            obs: ClassIncrementalSetting.Observations
            rew: ClassIncrementalSetting.Rewards
            for i, (obs, rew) in enumerate(train_pbar):
                self.optim.zero_grad()

                obs = obs.to(device=self.device)
                x = obs.x

                # FIXME: Batch norm will cause a crash if we pass x with batch_size==1!
                fake_batch = False
                if x.shape[0] == 1:
                    # Pretend like this has batch_size of 2 rather than just 1.
                    x = x.tile([2, *(1 for _ in x.shape[1:])])
                    x[1] += 1  # Just so the two samples aren't identical, otherwise
                    # maybe the batch norm std would be nan or something.
                    fake_batch = True
                logits = self.net(x)
                if fake_batch:
                    logits = logits[:1]  # Drop the 'fake' second item.

                if rew is None:
                    # If our online training performance is being measured, we might
                    # need to provide actions before we can get the corresponding
                    # rewards (image labels in this case).
                    y_pred = logits.argmax(1)
                    rew = train_env.send(y_pred)

                rew = rew.to(device=self.device)
                y = rew.y
                loss = F.cross_entropy(logits, y)

                postfix["loss"] = loss.detach().item()
                if self.task > 0 and self.buffer:
                    b_samples = self.buffer.sample(x.size(0))
                    b_logits = self.net(b_samples["x"])
                    loss_replay = F.cross_entropy(b_logits, b_samples["y"])
                    loss += loss_replay
                    postfix["replay loss"] = loss_replay.detach().item()

                loss.backward()
                self.optim.step()

                train_pbar.set_postfix(postfix)

                # Only add new samples to the buffer (only during first epoch).
                if self.buffer and epoch == 0:
                    self.buffer.add_reservoir({"x": x, "y": y, "t": self.task})

            # Validation loop:
            self.net.eval()
            torch.set_grad_enabled(False)
            val_pbar = tqdm.tqdm(valid_env)
            val_pbar.set_description(f"Validation Epoch {epoch}")
            epoch_val_loss_list: List[float] = []

            for i, (obs, rew) in enumerate(val_pbar):
                obs = obs.to(device=self.device)
                x = obs.x
                logits = self.net(x)

                if rew is None:
                    y_pred = logits.argmax(-1)
                    rew = valid_env.send(y_pred)

                assert rew is not None
                rew = rew.to(device=self.device)
                y = rew.y
                val_loss = F.cross_entropy(logits, y).item()

                epoch_val_loss_list.append(val_loss)
                postfix["validation loss"] = val_loss
                val_pbar.set_postfix(postfix)
            torch.set_grad_enabled(True)
            # Put the network back in training mode for the next training epoch.
            self.net.train()
            epoch_val_loss_mean = np.mean(epoch_val_loss_list)

            if epoch_val_loss_mean < best_val_loss:
                best_val_loss = epoch_val_loss_mean
                best_epoch = epoch
            if epoch - best_epoch > self.early_stop_patience:
                print(f"Early stopping at epoch {epoch}.")
                # TODO: Reload the weights from the best epoch.
                break
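
The `fake_batch` workaround above exists because BatchNorm layers in training mode need more than one value per channel to compute batch statistics; a batch of size 1 can therefore raise a ValueError (e.g. for BatchNorm1d, or BatchNorm2d on 1x1 feature maps). A tiny self-contained reproduction of the failure and of the duplicate-then-discard trick used above:

import torch
from torch import nn

bn = nn.BatchNorm1d(num_features=4)
x = torch.randn(1, 4)  # batch_size == 1

try:
    bn(x)  # In training mode this raises: not enough values per channel.
except ValueError as err:
    print(err)

# Workaround sketched in Example #6: duplicate the sample, perturb the copy so
# the batch statistics aren't degenerate, then keep only the first output row.
fake = x.tile((2, 1))
fake[1] += 1
out = bn(fake)[:1]
print(out.shape)  # torch.Size([1, 4])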