Esempio n. 1
0
    def save_model(self, val_acc, best_acc, step, best_step):
        """Checkpoint the model if ``val_acc`` improves on ``best_acc``.

        The direction of improvement comes from ``self.config.max_or_min``:
        ``'max'`` requires ``val_acc > best_acc`` and ``'min'`` requires
        ``val_acc < best_acc``. Ties, and any other ``max_or_min`` value,
        count as no improvement, so nothing is saved.

        Args:
            val_acc: Metric from the latest validation run.
            best_acc: Best metric seen so far.
            step: Current training step.
            best_step: Step at which ``best_acc`` was reached.

        Returns:
            ``(best_acc, best_step)`` after the update: ``(val_acc, step)``
            when a new checkpoint was written, otherwise the inputs unchanged.
        """
        direction = self.config.max_or_min
        is_better = (direction == 'max' and val_acc > best_acc) or (
            direction == 'min' and val_acc < best_acc
        )
        if not is_better:
            return best_acc, best_step

        # New best: persist model, head, optimizer and config via save_state
        # (defined elsewhere), tagged with the metric and the step.
        save_state(self.model, self.head, self.optimizer, self.config, val_acc, step)
        return val_acc, step
Esempio n. 2
0
    def run(self):
        """Train ``self.model`` and ``self.head`` for ``self.config.epochs`` epochs.

        Each batch goes through ``_train_batch`` (forward, loss, backward,
        optimizer step). The running training loss is written to TensorBoard
        every ``self.tensorboard_loss_every`` steps. Every
        ``self.evaluate_every`` steps (excluding step 0), one of two things
        happens:

        * if ``self.config.val_source`` is set, the model is evaluated and
          checkpointed on improvement via ``save_model``;
        * otherwise a checkpoint is saved unconditionally with metric ``0``.

        Learning rate: epochs listed in ``self.config.reduce_lr`` trigger
        ``self.reduce_lr()`` unless ``self.config.lr_plateau`` is set. In that
        case ``self.scheduler.step(val_acc)`` runs after every epoch instead.
        Optional early stopping is driven by ``self.early_stop``.

        After the epoch loop ends (normally or via early stopping), a final
        evaluation/checkpoint is performed using the same rule as inside the
        loop.
        """
        self.model.train()
        self.head.train()
        running_loss = 0.
        step = 0
        # NOTE(review): with lr_plateau/early_stop but no val_source, this
        # stays 0. and is what the scheduler/early-stopper receive.
        val_acc = 0.

        # Start from the worst possible value for the optimisation direction.
        best_step = 0
        best_acc = float('-inf') if self.config.max_or_min == 'max' else float('inf')

        for epoch in range(self.config.epochs):
            train_logger = TrainLogger(self.config.batch_size, self.config.frequency_log)

            # Milestone LR decay is only used when the plateau scheduler is off.
            if epoch + 1 in self.config.reduce_lr and not self.config.lr_plateau:
                self.reduce_lr()

            for idx, (imgs, labels) in enumerate(self.train_loader):
                loss_value = self._train_batch(imgs, labels)
                running_loss += loss_value

                if step % self.tensorboard_loss_every == 0:
                    # NOTE(review): at step 0 this divides a single batch loss
                    # by the full window size.
                    loss_board = running_loss / self.tensorboard_loss_every
                    self.writer.add_scalar('train_loss', loss_board, step)
                    running_loss = 0.

                if step % self.evaluate_every == 0 and step != 0:
                    if self.config.val_source is not None:
                        val_acc, _ = self.evaluate(step)
                        # Restore training mode after evaluation.
                        self.model.train()
                        self.head.train()
                        best_acc, best_step = self.save_model(val_acc, best_acc, step, best_step)
                        print(f'Best accuracy: {best_acc:.5f} at step {best_step}')
                    else:
                        save_state(self.model, self.head, self.optimizer, self.config, 0, step)

                train_logger(epoch, self.config.epochs, idx, len(self.train_loader), loss_value)
                step += 1

            if self.config.lr_plateau:
                self.scheduler.step(val_acc)

            if self.config.early_stop:
                self.early_stop(val_acc)
                if self.early_stop.stop:
                    print("Early stopping model...")
                    # Leaves only the epoch loop; the final checkpoint below still runs.
                    break

        # Final evaluation/checkpoint, guarded the same way as in the loop:
        # evaluation is only attempted when a validation source exists.
        if self.config.val_source is not None:
            val_acc, _ = self.evaluate(step)
            # save_model returns (best_acc, best_step); unpack both.
            best_acc, best_step = self.save_model(val_acc, best_acc, step, best_step)
            print(f'Best accuracy: {best_acc:.5f} at step {best_step}')
        else:
            save_state(self.model, self.head, self.optimizer, self.config, 0, step)

    def _train_batch(self, imgs, labels):
        """Run one optimisation step on a batch and return its loss as a float."""
        imgs = imgs.to(self.config.device)
        labels = labels.to(self.config.device)

        self.optimizer.zero_grad()

        embeddings = self.model(imgs)

        # The 'recognition' head also takes the labels in its forward pass.
        if self.config.attribute == 'recognition':
            outputs = self.head(embeddings, labels)
        else:
            outputs = self.head(embeddings)

        if self.weights is not None:
            loss = self.config.loss(outputs, labels, weight=self.weights)
        else:
            loss = self.config.loss(outputs, labels)

        loss.backward()
        # Read the scalar once; the caller reuses it for both the running
        # loss and the per-batch logger instead of calling .item() twice.
        loss_value = loss.item()

        self.optimizer.step()
        return loss_value