Example #1
0
File: dqn.py  Project: labmlai/battleship
    def run(self):
        """Play each configured game and train the DQN policy on the transitions.

        For every (game, arrangement) pair, steps the policy on a fresh board
        until the episode ends, pushing transitions into replay memory and
        training after each step. The target network is synced with the policy
        every `target_update` epochs.
        """
        pytorch_utils.add_model_indicators(self.policy)

        for epoch, (game, arrange) in enumerate(self.games):
            board = Board(arrange)

            # TODO change this
            state = board.get_current_board()

            iteration = 0
            while True:
                logger.log('epoch : {}, iteration : {}'.format(epoch, iteration), Color.cyan)

                action = self.get_action(state)
                next_state, reward, done = self.step(board, action.item())

                # Terminal transitions are stored with a None successor state
                if done:
                    next_state = None

                self.memory.push(state, action, next_state, reward)
                state = next_state

                self.train()

                if done:
                    tracker.add(iterations=iteration)
                    tracker.save()
                    break
                iteration += 1

            # Periodically copy policy weights into the target network
            if epoch % self.target_update == 0:
                self.target.load_state_dict(self.policy.state_dict())

            if self.is_log_parameters:
                pytorch_utils.store_model_indicators(self.policy)
Example #2
0
    def run(self):
        """Run the training loop: one train pass then one validation pass per step.

        When `is_log_parameters` is enabled, model indicators (presumably
        parameter histograms — confirm in `pytorch_utils`) are registered once
        up front and stored after every loop step.
        """
        if self.is_log_parameters:
            pytorch_utils.add_model_indicators(self.model)

        for _ in self.training_loop:
            # Each phase logs under its own tracker namespace
            with tracker.namespace('train'):
                self.trainer()
            with tracker.namespace('valid'):
                self.validator()
            if self.is_log_parameters:
                pytorch_utils.store_model_indicators(self.model)
Example #3
0
    def process(self, batch: any, state: any):
        """Run one GAN step on `batch`: k discriminator updates, then one generator update.

        Gradients are applied only when `MODE_STATE.is_train` is set; losses
        and sample images are logged to the tracker either way. Returns
        ({'samples': batch size}, None).
        """
        dev = self.discriminator.device
        real, labels = batch
        real, labels = real.to(dev), labels.to(dev)
        n = real.shape[0]

        # --- Discriminator: k updates on real images vs. detached fakes ---
        with monit.section("discriminator"):
            for _ in range(self.discriminator_k):
                noise = torch.randn(n, 100, device=dev)
                if MODE_STATE.is_train:
                    self.discriminator_optimizer.zero_grad()
                real_logits = self.discriminator(real)
                # detach() keeps generator weights out of this backward pass
                fake_logits = self.discriminator(self.generator(noise).detach())
                real_loss, fake_loss = self.discriminator_loss(real_logits, fake_logits)
                d_loss = real_loss + fake_loss

                # Log stuff
                tracker.add("loss.discriminator.true.", real_loss)
                tracker.add("loss.discriminator.false.", fake_loss)
                tracker.add("loss.discriminator.", d_loss)

                # Train
                if MODE_STATE.is_train:
                    d_loss.backward()
                    if MODE_STATE.is_log_parameters:
                        pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
                    self.discriminator_optimizer.step()

        # --- Generator: single update through the (frozen-step) discriminator ---
        with monit.section("generator"):
            noise = torch.randn(n, 100, device=dev)
            if MODE_STATE.is_train:
                self.generator_optimizer.zero_grad()
            fakes = self.generator(noise)
            g_loss = self.generator_loss(self.discriminator(fakes))

            # Log a handful of generated samples plus the loss
            tracker.add('generated', fakes[0:5])
            tracker.add("loss.generator.", g_loss)

            # Train
            if MODE_STATE.is_train:
                g_loss.backward()
                if MODE_STATE.is_log_parameters:
                    pytorch_utils.store_model_indicators(self.generator, 'generator')
                self.generator_optimizer.step()

        return {'samples': len(real)}, None
Example #4
0
    def run(self):
        """Register tracker metrics, then run the train/test loop.

        Stores model indicators after each loop step when `is_log_parameters`
        is enabled.
        """
        pytorch_utils.add_model_indicators(self.model)

        # Declare how each metric is tracked: a bounded queue (size 20) for
        # training loss, a histogram for validation loss, a plain scalar for
        # validation accuracy (the booleans presumably toggle printing —
        # confirm against the tracker API).
        tracker.set_queue("train.loss", 20, True)
        tracker.set_histogram("valid.loss", True)
        tracker.set_scalar("valid.accuracy", True)

        for _ in self.training_loop:
            self.train()
            self.test()
            if self.is_log_parameters:
                pytorch_utils.store_model_indicators(self.model)
Example #5
0
    def run(self):
        """Register metrics, snapshot validation inputs, then train/validate.

        Saves the raw validation inputs once up front so that the per-sample
        indexed metrics logged later can be matched back to their inputs.
        """
        # Training and testing
        pytorch_utils.add_model_indicators(self.model)

        # Metric declarations for the tracker
        tracker.set_queue("train.loss", 20, True)
        tracker.set_histogram("valid.loss", True)
        tracker.set_scalar("valid.accuracy", True)
        tracker.set_indexed_scalar('valid.sample_loss')
        tracker.set_indexed_scalar('valid.sample_pred')

        # Persist the validation inputs (first element of each dataset item)
        inputs = np.array([sample[0].numpy() for sample in self.valid_dataset])
        experiment.save_numpy("valid.data", inputs)

        for _ in self.training_loop:
            self.train()
            self.valid()
            if self.is_log_parameters:
                pytorch_utils.store_model_indicators(self.model)
Example #6
0
    def step(self, batch: any, batch_idx: BatchIndex):
        """Run a forward (and, in train mode, backward) pass over one batch.

        Logs the loss and accuracy every call; in train mode also advances the
        global step, applies an optimizer update, and — on the last batch —
        stores model indicators before clearing gradients.
        """
        data = batch[0].to(self.device)
        target = batch[1].to(self.device)

        if self.mode.is_train:
            # Advance the global step by the number of samples processed
            tracker.add_global_step(len(data))

        # Log activations only for the final batch of the epoch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            prediction = self.model(data)

        loss = self.loss_func(prediction, target)
        self.accuracy_func(prediction, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            # Indicators are stored while gradients are still populated,
            # just before they are zeroed for the next batch
            if batch_idx.is_last:
                pytorch_utils.store_model_indicators(self.model)
            self.optimizer.zero_grad()

        tracker.save()
Example #7
0
 def collect_value(self, model: 'Module'):
     """Store indicator values for *model*, tagged with this object's name."""
     store_model_indicators(model, model_name=self.name)
Example #8
0
File: rnn.py  Project: whiskey1/samples
    def __log_model_params(self):
        """Store encoder parameter indicators, if parameter logging is enabled."""
        if self.__is_log_parameters:
            pytorch_utils.store_model_indicators(self.encoder)
Example #9
0
    def __log_model_params(self):
        """Add histograms with model parameter values and gradients, when enabled."""
        if self.__is_log_parameters:
            pytorch_utils.store_model_indicators(self.model)