Example 1
    def run(self):
        device = self.model_config.torch_device()
        model = self.model_factory.instantiate().to(device)

        start_epoch = self.storage.last_epoch_idx()

        training_info = TrainingInfo(
            start_epoch_idx=start_epoch,
            run_name=self.model_config.run_name,
        )

        model_state, hidden_state = self.storage.load(training_info)
        model.load_state_dict(model_state)

        model.eval()

        current_char = self.start_letter
        current_char_encoded = self.source.encode_character(self.start_letter)

        generated_text = [current_char]

        state = model.zero_state(1).to(device)

        char_tensor = torch.from_numpy(
            np.array([current_char_encoded])).view(1, 1).to(device)

        for _ in tqdm.trange(self.length):
            prob_logits, state = model.forward_state(char_tensor, state)

            # Apply temperature to the logits
            prob_logits = F.log_softmax(
                prob_logits.view(-1) / self.temperature, dim=0)

            distribution = dist.Categorical(logits=prob_logits)

            char_tensor = distribution.sample().view(1, 1)
            current_char_encoded = char_tensor.item()

            if current_char_encoded == 0:
                # End of sequence marker
                break

            current_char = self.source.decode_character(current_char_encoded)

            generated_text.append(current_char)

        print(
            "============================ START GENERATED TEXT ================================================"
        )
        print(''.join(generated_text))
        print(
            "============================ END GENERATED TEXT ================================================"
        )
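
The sampling step inside the loop (temperature scaling followed by a categorical draw) is self-contained and easy to test on its own. A minimal sketch, assuming only PyTorch; the function name is illustrative:

import torch
import torch.nn.functional as F
import torch.distributions as dist


def sample_with_temperature(logits, temperature):
    """Draw a class index from raw logits scaled by a temperature.

    Temperatures below 1 sharpen the distribution (closer to argmax),
    temperatures above 1 flatten it (more random output).
    """
    scaled = F.log_softmax(logits.view(-1) / temperature, dim=0)
    return dist.Categorical(logits=scaled).sample().item()


# Three classes; a low temperature picks index 0 almost every time
print(sample_with_temperature(torch.tensor([2.0, 0.5, -1.0]), 0.5))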
Example 2
    def resume_training(self, reinforcer, callbacks, metrics) -> TrainingInfo:
        """ Possibly resume training from a saved state from the storage """
        if self.model_config.continue_training:
            start_epoch = self.storage.last_epoch_idx()
        else:
            start_epoch = 0

        training_info = TrainingInfo(start_epoch_idx=start_epoch,
                                     run_name=self.model_config.run_name,
                                     metrics=metrics,
                                     callbacks=callbacks)

        if start_epoch == 0:
            self.storage.reset(self.model_config.render_configuration())
            training_info.initialize()
            reinforcer.initialize_training(training_info)
        else:
            self.storage.resume(training_info, reinforcer.model)

        return training_info
Example 3
    def resume_training(self, learner, callbacks,
                        metrics) -> (TrainingInfo, dict):
        """ Possibly resume training from a saved state from the storage """
        if self.model_config.reset:
            start_epoch, hidden_state = 0, {}
        else:
            start_epoch, hidden_state = self.storage.resume_learning(
                learner.model)

        training_info = TrainingInfo(start_epoch_idx=start_epoch,
                                     metrics=metrics,
                                     callbacks=callbacks)

        if start_epoch > 0:
            for callback in callbacks:
                callback.load_state_dict(hidden_state)

            training_info.restore(hidden_state)

        return training_info, hidden_state
Example 4
    def resume_training(self, learner, callbacks,
                        metrics) -> (TrainingInfo, dict):
        """ Possibly resume training from a saved state from the storage """
        if self.model_config.reset:
            start_epoch = 0
        else:
            start_epoch = self.storage.last_epoch_idx()

        training_info = TrainingInfo(start_epoch_idx=start_epoch,
                                     run_name=self.model_config.run_name,
                                     metrics=metrics,
                                     callbacks=callbacks)

        if start_epoch == 0:
            self.storage.reset(self.model_config.render_configuration())
            training_info.initialize()
            learner.initialize_training(training_info)
            hidden_state = None
        else:
            hidden_state = self.storage.resume(training_info, learner.model)

        return training_info, hidden_state
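
Examples 2-4 are variants of the same reset-or-resume decision: start at epoch 0 and initialize, or restore saved state and continue. Stripped of the vel storage layer, the pattern reduces to the sketch below; the checkpoint path and dict layout are assumptions, not vel's storage format:

import os

import torch


def resume_or_init(model, optimizer, checkpoint_path):
    """ Return the epoch to start from, restoring state if a checkpoint exists """
    if not os.path.exists(checkpoint_path):
        return 0  # Fresh run - caller performs first-time initialization
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'] + 1  # Continue after the last finished epoch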
Example 5
    def resume_training(self, reinforcer, optimizer, callbacks,
                        metrics) -> TrainingInfo:
        """ Possibly resume training from a saved state from the storage """
        global_epoch_idx = 1

        # TODO(jerry): Implement training resume
        training_info = TrainingInfo(start_epoch_idx=global_epoch_idx,
                                     metrics=metrics,
                                     callbacks=callbacks)

        training_info['run_name'] = self.model_config.run_name

        return training_info
Example 6
    def run(self):
        device = self.model_config.torch_device()

        env = self.env_factory.instantiate_single(preset='record',
                                                  seed=self.model_config.seed)
        model = self.model_factory.instantiate(
            action_space=env.action_space).to(device)

        training_info = TrainingInfo(
            start_epoch_idx=self.storage.last_epoch_idx(),
            run_name=self.model_config.run_name)

        model_state, hidden_state = self.storage.load(training_info)
        model.load_state_dict(model_state)

        model.eval()

        for i in range(self.takes):
            self.record_take(model, env, device, take_number=i + 1)
Example 7
    def run(self):
        device = torch.device(self.model_config.device)

        env = self.env_factory.instantiate(preset='raw')

        if self.frame_history:
            env = FrameStack(env, self.frame_history)

        model = self.model_factory.instantiate(
            action_space=env.action_space).to(device)

        training_info = TrainingInfo(
            start_epoch_idx=self.storage.last_epoch_idx(),
            run_name=self.model_config.run_name)
        self.storage.resume(training_info, model)

        model.eval()

        for i in range(self.takes):
            self.record_take(model, env, device, take_number=i + 1)
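
Examples 6 and 7 both load a trained model, switch it to eval mode, and record a number of takes. The record_take body is not shown; a generic rollout under torch.no_grad() would look roughly like the sketch below, where a gym-style env interface is assumed and policy_fn is a placeholder for however the model output is turned into an action (this is illustrative, not vel's record_take):

import torch


@torch.no_grad()
def rollout(model, env, device, policy_fn):
    """ Play one episode with a frozen model """
    observation = env.reset()
    done = False
    while not done:
        obs_tensor = torch.from_numpy(observation).float().unsqueeze(0).to(device)
        action = policy_fn(model, obs_tensor)
        observation, reward, done, info = env.step(action)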
Example 8
                action_space=vec_env.action_space),
            normalize_returns=True,
            discount_factor=0.99),
    )

    # Optimizer helper - weird regularization settings I've copied from OpenAI code
    adam_optimizer = AdamFactory(lr=[1.0e-4, 1.0e-3, 1.0e-3],
                                 weight_decay=[0.0, 0.0, 0.001],
                                 eps=1.0e-4,
                                 layer_groups=True).instantiate(model)

    # Overall store for training information
    training_info = TrainingInfo(
        metrics=[
            # Calculate average reward per episode
            EpisodeRewardMetric('episode_rewards'),
        ],
        # Print live metrics every epoch to standard output
        callbacks=[StdoutStreaming()],
    )

    # A bit of training initialization bookkeeping...
    training_info.initialize()
    reinforcer.initialize_training(training_info)
    training_info.on_train_begin()

    # Let's make 20 batches per epoch to average metrics nicely
    num_epochs = int(1.0e6 / 2 / 1000)

    # Normal handrolled training loop
    for i in range(1, num_epochs + 1):
        epoch_info = EpochInfo(training_info=training_info,
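
The lists passed to AdamFactory above define per-layer-group learning rates and weight decay. In plain PyTorch the same idea is expressed with optimizer parameter groups; a toy sketch with two groups (the factory call uses three, and vel's actual layer grouping may differ):

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

# One dict per layer group - each gets its own lr and weight_decay,
# mirroring the lr=[...] and weight_decay=[...] lists above
optimizer = torch.optim.Adam(
    [
        {'params': model[0].parameters(), 'lr': 1.0e-4, 'weight_decay': 0.0},
        {'params': model[2].parameters(), 'lr': 1.0e-3, 'weight_decay': 0.001},
    ],
    eps=1.0e-4)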
Example 9
    def run(self):
        """ Run the command with supplied configuration """
        device = torch.device(self.model_config.device)
        learner = Learner(device, self.model.instantiate())

        lr_schedule = interp.interpolate_series(self.start_lr, self.end_lr,
                                                self.num_it,
                                                self.interpolation)

        if self.freeze:
            learner.model.freeze()

        # Optimizer should be created after the freeze
        optimizer = self.optimizer_factory.instantiate(learner.model)

        iterator = iter(self.source.train_loader())

        # Metrics to track through this training
        metrics = learner.metrics() + [AveragingNamedMetric("lr")]

        learner.train()

        best_value = None

        training_info = TrainingInfo(start_epoch_idx=0, metrics=metrics)

        # Treat it all as one epoch
        epoch_info = EpochInfo(training_info,
                               global_epoch_idx=1,
                               batches_per_epoch=1,
                               optimizer=optimizer)

        for iteration_idx, lr in enumerate(tqdm.tqdm(lr_schedule)):
            batch_info = BatchInfo(epoch_info, iteration_idx)

            # First, set the learning rate, the same for each parameter group
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            try:
                data, target = next(iterator)
            except StopIteration:
                iterator = iter(self.source.train_loader())
                data, target = next(iterator)

            learner.train_batch(batch_info, data, target)

            batch_info['lr'] = lr

            # METRIC RECORDING PART
            epoch_info.result_accumulator.calculate(batch_info)

            current_value = epoch_info.result_accumulator.intermediate_value(
                self.metric)

            final_metrics = {
                'epoch_idx': iteration_idx,
                self.metric: current_value,
                'lr': lr
            }

            if best_value is None or current_value < best_value:
                best_value = current_value

            # Stop on divergence
            if self.stop_dv and (np.isnan(current_value) or current_value >
                                 best_value * self.divergence_threshold):
                break

            training_info.history.add(final_metrics)

        frame = training_info.history.frame()

        fig, ax = plt.subplots(1, 2)

        ax[0].plot(frame.index, frame.lr)
        ax[0].set_title("LR Schedule")
        ax[0].set_xlabel("Num iterations")
        ax[0].set_ylabel("Learning rate")

        if self.interpolation == 'logscale':
            ax[0].set_yscale("log", nonposy='clip')

        ax[1].plot(frame.lr, frame[self.metric], label=self.metric)
        # ax[1].plot(frame.lr, frame[self.metric].ewm(com=20).mean(), label=self.metric + ' smooth')
        ax[1].set_title(self.metric)
        ax[1].set_xlabel("Learning rate")
        ax[1].set_ylabel(self.metric)
        # ax[1].legend()

        if self.interpolation == 'logscale':
            ax[1].set_xscale("log", nonposx='clip')

        plt.show()
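
The schedule driving the loop comes from interp.interpolate_series(start_lr, end_lr, num_it, interpolation). A plausible sketch of such a helper, assuming NumPy; the real vel implementation may differ:

import numpy as np


def interpolate_series(start, end, steps, how='logscale'):
    """ Interpolate 'steps' values from start to end, linearly or in log space """
    if how == 'logscale':
        return np.logspace(np.log10(start), np.log10(end), num=steps)
    return np.linspace(start, end, num=steps)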