Code example #1
import logging
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Project-local imports assumed available in this repository:
# Vocabulary, TopicClassDataLoader, TopicClassCNN, TensorboardWriter,
# MetricTracker, accuracy


class Trainer(object):
    def __init__(self, config):
        self.logger = logging.getLogger("Training")
        self.config = config
        self.start_epoch = 1
        self.monitor = self.config.monitor
        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = np.inf if self.mnt_mode == 'min' else -np.inf
            self.early_stop = self.config.early_stop
        self.prepare_device()

        self.logger.info("Creating tensorboard writer...")
        self.writer = TensorboardWriter(log_dir=self.config.summary_dir,
                                        logger=self.logger,
                                        enabled=True)

        self.logger.info("Creating data loaders...")
        self.build_data_loader()

        self.logger.info("Creating model architecture...")
        self.build_model()

        self.logger.info("Creating optimizers...")
        self.build_optimizer()

        self.logger.info("Creating losses...")
        self.build_loss()

        self.logger.info("Creating metric trackers...")
        self.build_metrics()

        self.logger.info("Creating checkpoints...")
        self.load_checkpoint(self.config.checkpoint, self.config.resume_epoch)

        self.logger.info("Check parallelism...")
        self.parallelism()

    def build_model(self):
        self.model = TopicClassCNN(
            vocab_size=self.config.vocab_size,  # BERT uncased vocab (30522) already includes [PAD]
            emb_size=self.config.emb_size,
            n_layers=self.config.n_layers,
            attn_heads=self.config.attn_heads,
            dropout=self.config.dropout,
            num_classes=self.config.num_classes)

        # load pretrained bert model
        # checkpoint = '/home/haochen/Projects/cnn-topic-classification/pretrained_bert/uncased_L-12_H-768_A-12/bert_model_sent.pth'
        # state_dict = torch.load(checkpoint)
        # self.model.load_state_dict(state_dict)

    def build_data_loader(self):
        self.vocabulary = Vocabulary()
        self.config.vocab_size = 30522  # vocabulary size of the uncased BERT WordPiece tokenizer
        assert self.config.vocab_size == 30522, \
            "vocabulary size does not match the BERT config"
        assert self.config.num_classes == self.vocabulary.get_num_topics(), \
            "class number doesn't match topic number"

        self.train_loader = TopicClassDataLoader(
            vocabulary=self.vocabulary,
            split='train',
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers)

        self.valid_loader = TopicClassDataLoader(
            vocabulary=self.vocabulary,
            split='valid',
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers)

    def build_optimizer(self):
        self.optimizer = optim.AdamW(self.model.parameters(),
                                     self.config.lr,
                                     weight_decay=self.config.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=self.config.exp_lr_gamma)

    def build_loss(self):
        self.cls_loss = nn.CrossEntropyLoss()
        self.cls_loss.to(self.device)

    def build_metrics(self):
        loss_tags = ['loss', 'acc_1', 'acc_5']
        self.train_metrics = MetricTracker(*loss_tags, writer=self.writer)
        self.val_metrics = MetricTracker(*loss_tags, writer=self.writer)

    def train(self):
        not_improved_count = 0
        self.logger.info("Starting training...")

        for epoch in range(self.start_epoch, self.config.epochs + 1):
            result = self.train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            self.save_checkpoint('latest', save_best=False)
            if epoch % self.config.save_period == 0:
                self.save_checkpoint(epoch, save_best=best)

            self.config.resume_epoch = epoch

    def train_epoch(self, epoch):
        self.model.train()
        for batch_idx, (sentences, gt_topics) in enumerate(self.train_loader):
            sentences = sentences.to(self.device)
            gt_topics = gt_topics.to(self.device)

            pred_topics = self.model(sentences)
            loss = self.cls_loss(pred_topics, gt_topics)
            acc = accuracy(pred_topics, gt_topics)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # write batch-level training summaries to tensorboard
            self.writer.set_step(
                (epoch - 1) * len(self.train_loader) + batch_idx, mode='train')
            self.train_metrics.update('loss', loss.item())
            self.train_metrics.update('acc_1', acc[0].item())
            self.train_metrics.update('acc_5', acc[1].item())

            # log on console
            if batch_idx % self.config.summary_step == 0:
                self.logger.info(
                    'Train Epoch: {} {} Loss: {:.4f}, Acc_1: {:.2f}, Acc_5: {:.2f}'
                    .format(epoch, self._progress(batch_idx), loss.item(),
                            acc[0].item(), acc[1].item()))

        self.lr_scheduler.step()
        log = self.train_metrics.result()
        val_log = self.valid_epoch(epoch)
        log.update(**{'val_' + k: v for k, v in val_log.items()})
        return log

    def valid_epoch(self, epoch):
        self.model.eval()
        val_loss = []
        val_acc = []
        val_acc5 = []
        with torch.no_grad():
            for batch_idx, (sentences,
                            gt_topics) in enumerate(self.valid_loader):
                sentences, gt_topics = sentences.to(self.device), gt_topics.to(
                    self.device)

                pred_topics = self.model(sentences)
                loss = self.cls_loss(pred_topics, gt_topics)
                acc = accuracy(pred_topics, gt_topics)

                val_loss.append(loss.item())
                val_acc.append(acc[0].item())
                val_acc5.append(acc[1].item())

        self.writer.set_step(epoch, mode='val')
        self.val_metrics.update('loss', np.mean(val_loss))
        self.val_metrics.update('acc_1', np.mean(val_acc))
        self.val_metrics.update('acc_5', np.mean(val_acc5))

        return self.val_metrics.result()

    def save_checkpoint(self, epoch, save_best):
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
        filename = 'epoch_{}.pth'.format(epoch)
        torch.save(state, os.path.join(self.config.checkpoint_dir, filename))
        if save_best:
            best_path = os.path.join(self.config.checkpoint_dir,
                                     'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def load_checkpoint(self, checkpoint_dir=None, epoch=None):
        if checkpoint_dir is None:
            self.logger.info("Training from scratch...")
            self.model.to(self.device)
            self.start_epoch = 1
            return

        self.logger.info(
            "Loading checkpoints from {}...".format(checkpoint_dir))
        self.start_epoch = epoch + 1
        self.logger.info("Continuing training from epoch {}...".format(epoch))
        filename = 'epoch_{}.pth'.format(epoch)
        checkpoint = torch.load(os.path.join(checkpoint_dir, filename))
        model_to_load = {
            k.replace('module.', ''): v
            for k, v in checkpoint['state_dict'].items()
        }
        self.model.load_state_dict(model_to_load)
        self.model.to(self.device)
        if self.config.mode == 'train':
            self.optimizer.load_state_dict(checkpoint['optimizer'])

    def prepare_device(self):
        self.cuda = torch.cuda.is_available()
        if self.cuda:
            self.device = torch.device("cuda:0")
            self.logger.info("Training will be conducted on GPU")
        else:
            self.device = torch.device("cpu")
            self.logger.info("Training will be conducted on CPU")

        n_gpu = torch.cuda.device_count()
        n_gpu_use = self.config.ngpu

        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There's no GPU available on this machine, "
                "training will be performed on CPU.")
            n_gpu_use = 0
            self.config.ngpu = n_gpu_use
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPUs configured to use is {}, but only {} are available "
                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        self.device_ids = list(range(n_gpu_use))

    def parallelism(self):
        if len(self.device_ids) > 1:
            self.logger.info("Using {} GPUs...".format(len(self.device_ids)))
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.device_ids)
        else:
            if self.cuda:
                self.logger.info(
                    "Using a single GPU; the model is not parallelized...")
            else:
                self.logger.info("Using CPU...")

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = len(self.train_loader)
        return base.format(current, total, 100.0 * current / total)
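
A minimal usage sketch for the Trainer above (not part of the original code): the config field names are the attributes the class reads, while the SimpleNamespace container and the concrete values are illustrative assumptions.

# Usage sketch (hypothetical values; the field names mirror the attributes
# read by Trainer). Assumes the imports and project modules listed above.
from types import SimpleNamespace

config = SimpleNamespace(
    monitor='max val_acc_1',   # '<min|max> <metric>' or 'off'
    early_stop=10,
    epochs=50, save_period=5, summary_step=100,
    batch_size=64, num_workers=4,
    lr=1e-4, weight_decay=1e-2, exp_lr_gamma=0.95,
    summary_dir='experiments/topic_cls/summary',
    checkpoint_dir='experiments/topic_cls/checkpoints',
    checkpoint=None, resume_epoch=None, mode='train',
    ngpu=1,
    # consumed by build_data_loader() / build_model();
    # num_classes must equal Vocabulary().get_num_topics()
    vocab_size=30522, emb_size=768, n_layers=4, attn_heads=8,
    dropout=0.1, num_classes=10)

trainer = Trainer(config)
trainer.train()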
Code example #2
import logging
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.autograd import Variable

# Project-local imports assumed available in this repository:
# CelebADataLoader, ResNet34, Normalize, TensorboardWriter, MetricTracker,
# accuracy


class CelebATrainer(object):
    def __init__(self, config):
        self.logger = logging.getLogger("Training")
        self.config = config
        self.start_epoch = 1
        self.monitor = self.config.monitor
        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = np.inf if self.mnt_mode == 'min' else -np.inf
            self.early_stop = self.config.early_stop
        self.prepare_device()

        self.logger.info("Creating tensorboard writer...")
        self.writer = TensorboardWriter(log_dir=self.config.summary_dir,
                                        logger=self.logger,
                                        enabled=True)

        self.logger.info("Creating model architecture...")
        self.build_model()

        self.logger.info("Creating data loaders...")
        self.build_data_loader()

        self.logger.info("Creating optimizers...")
        self.build_optimizer()

        self.logger.info("Creating losses...")
        self.build_loss()

        self.logger.info("Creating metric trackers...")
        self.build_metrics()

        self.logger.info("Creating checkpoints...")
        self.load_checkpoint(self.config.checkpoint, self.config.resume_epoch)

        self.logger.info("Check parallelism...")
        self.parallelism()

        # save config file into model directory
        self.logger.info("Saving config...")
        with open(os.path.join('experiments', config.exp_name, 'config.yaml'),
                  'w') as f:
            yaml.dump(config, f)

    def build_model(self):
        self.model = ResNet34(attr_dim=len(self.config.attrs)
                              if len(self.config.attrs) != 0 else 40)

    def build_data_loader(self):
        self.train_loader = CelebADataLoader(
            data_dir=self.config.data_dir,
            split='train',
            attr_names=self.config.attrs,
            img_size=self.config.image_size,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers)

        self.valid_loader = CelebADataLoader(
            data_dir=self.config.data_dir,
            split='valid',
            attr_names=self.config.attrs,
            img_size=self.config.image_size,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers)

        self.norm = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    def build_optimizer(self):
        self.optimizer = optim.AdamW(self.model.parameters(),
                                     self.config.lr,
                                     [self.config.beta1, self.config.beta2],
                                     weight_decay=self.config.wd)
        self.lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=self.config.cos_restart_t0,
            T_mult=self.config.cos_restart_t_mult,
            eta_min=1e-5)

    def build_loss(self):
        self.cls_loss = nn.BCEWithLogitsLoss()
        self.cls_loss.to(self.device)

    def build_metrics(self):
        loss_tags = ['loss', 'acc_avg']
        acc_tags = ['acc_' + attr_name for attr_name in self.config.attrs]
        loss_tags.extend(acc_tags)
        self.train_metrics = MetricTracker(*loss_tags, writer=self.writer)
        self.val_metrics = MetricTracker(*loss_tags, writer=self.writer)

    def train(self):
        not_improved_count = 0
        self.logger.info("Starting training...")
        # initialize global noise data
        self.noise = torch.zeros([
            self.config.batch_size, 3, self.config.image_size,
            self.config.image_size
        ]).to(self.device)
        self.config.clip_eps /= 255.0
        self.config.fgsm_step /= 255.0

        for epoch in range(self.start_epoch, self.config.epochs + 1):
            result = self.train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            self.save_checkpoint('latest', save_best=False)
            if epoch % self.config.save_period == 0:
                self.save_checkpoint(epoch, save_best=best)

            self.config.resume_epoch = epoch

    def train_epoch(self, epoch):
        self.model.train()
        # "free" adversarial training: reuse a persistent noise buffer to
        # improve generalization and robustness to input noise
        for batch_idx, (images, labels) in enumerate(self.train_loader):
            images = images.to(self.device)
            labels = labels.to(self.device)

            noise = Variable(self.noise[0:images.size(0)],
                             requires_grad=True).to(self.device)
            noisy_images = images + noise
            noisy_images.clamp_(0, 1.0)
            self.norm.do(noisy_images)
            outputs = self.model(noisy_images)
            loss = self.cls_loss(outputs, labels)
            acc = accuracy(outputs, labels)

            self.optimizer.zero_grad()
            loss.backward()

            # Update the noise for the next iteration
            pert = self.config.fgsm_step * torch.sign(noise.grad)
            self.noise[:images.size(0)] += pert.data
            self.noise.clamp_(-self.config.clip_eps, self.config.clip_eps)

            self.optimizer.step()
            self.lr_scheduler.step()

            # write batch-level training summaries to tensorboard
            self.writer.set_step(
                (epoch - 1) * len(self.train_loader) + batch_idx, mode='train')
            self.train_metrics.update('loss', loss.item())
            self.train_metrics.update('acc_avg', np.mean(acc))
            for attr_acc, attr in zip(acc, self.config.attrs):
                self.train_metrics.update('acc_' + attr, attr_acc)

            # log on console
            if batch_idx % self.config.summary_step == 0:
                self.logger.info(
                    'Train Epoch: {} {} Loss: {:.4f}, Acc: {:.2f}'.format(
                        epoch, self._progress(batch_idx), loss.item(),
                        np.mean(acc)))

        log = self.train_metrics.result()
        val_log = self.valid_epoch(epoch)
        log.update(**{'val_' + k: v for k, v in val_log.items()})
        return log

    def valid_epoch(self, epoch):
        self.model.eval()
        val_loss = []
        val_acc = []
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(self.valid_loader):
                images, labels = images.to(self.device), labels.to(self.device)

                outputs = self.model(images)
                loss = self.cls_loss(outputs, labels)
                acc = accuracy(outputs, labels)

                val_loss.append(loss.item())
                val_acc.append(acc)

        self.writer.set_step(epoch, mode='val')
        self.val_metrics.update('loss', np.mean(val_loss))
        attr_acc = np.mean(val_acc, axis=0)
        self.val_metrics.update('acc_avg', np.mean(attr_acc))
        for acc, attr in zip(attr_acc, self.config.attrs):
            self.val_metrics.update('acc_' + attr, acc)

        return self.val_metrics.result()

    def save_checkpoint(self, epoch, save_best):
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
        filename = 'epoch_{}.pth'.format(epoch)
        torch.save(state, os.path.join(self.config.checkpoint_dir, filename))
        if save_best:
            best_path = os.path.join(self.config.checkpoint_dir,
                                     'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def load_checkpoint(self, checkpoint_dir=None, epoch=None):
        if checkpoint_dir is None:
            self.logger.info("Training from scratch...")
            self.model.to(self.device)
            self.start_epoch = 1
            return

        self.logger.info(
            "Loading checkpoints from {}...".format(checkpoint_dir))
        self.start_epoch = epoch + 1
        self.logger.info("Continuing training from epoch {}...".format(epoch))
        filename = 'epoch_{}.pth'.format(epoch)
        checkpoint = torch.load(os.path.join(checkpoint_dir, filename))
        model_to_load = {
            k.replace('module.', ''): v
            for k, v in checkpoint['state_dict'].items()
        }
        self.model.load_state_dict(model_to_load)
        self.model.to(self.device)
        if self.config.mode == 'train':
            self.optimizer.load_state_dict(checkpoint['optimizer'])

    def prepare_device(self):
        self.cuda = torch.cuda.is_available() and self.config.cuda
        if self.cuda:
            self.device = torch.device("cuda:0")
            self.logger.info("Training will be conducted on GPU")
        else:
            self.device = torch.device("cpu")
            self.logger.info("Training will be conducted on CPU")

        n_gpu = torch.cuda.device_count()
        n_gpu_use = self.config.ngpu

        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There's no GPU available on this machine, "
                "training will be performed on CPU.")
            n_gpu_use = 0
            self.config.ngpu = n_gpu_use
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPUs configured to use is {}, but only {} are available "
                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        self.device_ids = list(range(n_gpu_use))

    def parallelism(self):
        if len(self.device_ids) > 1:
            self.logger.info("Using {} GPUs...".format(len(self.device_ids)))
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.device_ids)
        else:
            if self.cuda:
                self.logger.info(
                    "Using a single GPU; the model is not parallelized...")
            else:
                self.logger.info("Using CPU...")

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = len(self.train_loader)
        return base.format(current, total, 100.0 * current / total)
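
The train_epoch above interleaves an FGSM-style update of a persistent noise buffer with the usual parameter update ("free" adversarial training). The sketch below isolates that step outside the trainer; all names are illustrative, and it assumes inputs already scaled to [0, 1].

# Standalone sketch of the "free" adversarial step used in train_epoch
# (illustrative names; not part of the original repository).
import torch

def free_adv_step(model, criterion, optimizer, images, labels,
                  global_noise, fgsm_step, clip_eps):
    # Reuse the persistent noise buffer for this batch and track its gradient.
    noise = global_noise[:images.size(0)].clone().detach().requires_grad_(True)
    noisy = (images + noise).clamp(0.0, 1.0)

    loss = criterion(model(noisy), labels)
    optimizer.zero_grad()
    loss.backward()  # fills both the parameter grads and noise.grad

    # Ascent step on the noise (sign of its gradient), kept inside the
    # epsilon ball, followed by the usual descent step on the parameters.
    with torch.no_grad():
        global_noise[:images.size(0)] += fgsm_step * noise.grad.sign()
        global_noise.clamp_(-clip_eps, clip_eps)
    optimizer.step()
    return loss.item()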
Code example #3
    def build_metrics(self):
        loss_tags = ['loss', 'acc_1', 'acc_5']
        self.train_metrics = MetricTracker(*loss_tags, writer=self.writer)
        self.val_metrics = MetricTracker(*loss_tags, writer=self.writer)
Code example #4
    def build_metrics(self):
        loss_tags = ['loss', 'acc_avg']
        acc_tags = ['acc_' + attr_name for attr_name in self.config.attrs]
        loss_tags.extend(acc_tags)
        self.train_metrics = MetricTracker(*loss_tags, writer=self.writer)
        self.val_metrics = MetricTracker(*loss_tags, writer=self.writer)
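
The two build_metrics variants differ only in the tag list they register. For reference, a minimal MetricTracker compatible with the calls used here (a constructor taking tag names and a writer, update(), result()) could look like the sketch below; the repository's actual class, and the writer.add_scalar hook, are assumptions rather than the original implementation.

class MetricTracker:
    # Minimal sketch: keeps a running average per tag and optionally logs
    # each update through the (assumed) writer.add_scalar interface.
    def __init__(self, *tags, writer=None):
        self.writer = writer
        self._totals = {tag: 0.0 for tag in tags}
        self._counts = {tag: 0 for tag in tags}

    def update(self, tag, value, n=1):
        self._totals[tag] += value * n
        self._counts[tag] += n
        if self.writer is not None:
            self.writer.add_scalar(tag, value)

    def result(self):
        # Average of every tracked metric since construction.
        return {tag: self._totals[tag] / max(self._counts[tag], 1)
                for tag in self._totals}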