コード例 #1
0
    def __init__(self, model, criterion, optimizer, scheduler, metric_ftns,
                 device, num_epoch, grad_clipping, grad_accumulation_steps,
                 early_stopping, validation_frequency, tensorboard,
                 checkpoint_dir, resume_path):
        """Set up the trainer: device placement, optional resume, and metrics.

        Moves the model to the resolved device, optionally restores state
        from ``resume_path``, wraps the model in ``DataParallel`` when more
        than one GPU is available, and builds the tensorboard writer plus
        train/validation metric trackers.
        """
        self.device, device_ids = self._prepare_device(device)
        # Bug fix: always bind self.model. The original only assigned it on
        # the multi-GPU path or behind an undefined `use_cuda` flag, so a
        # single-device run could raise AttributeError (or NameError) later.
        self.model = model.to(self.device)

        self.start_epoch = 1
        if resume_path is not None:
            self._resume_checkpoint(resume_path)
        if len(device_ids) > 1:
            # Wrap the already-moved model; the wrapper becomes the model
            # that forward passes go through.
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=device_ids)
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.num_epoch = num_epoch
        self.scheduler = scheduler
        self.grad_clipping = grad_clipping
        self.grad_accumulation_steps = grad_accumulation_steps
        self.early_stopping = early_stopping
        self.validation_frequency = validation_frequency
        self.checkpoint_dir = checkpoint_dir
        self.best_epoch = 1
        self.best_score = 0
        self.writer = TensorboardWriter(
            os.path.join(checkpoint_dir, 'tensorboard'), tensorboard)
        self.train_metrics = MetricTracker('loss', writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
コード例 #2
0
    def build(self):
        """Construct the GAN modules on the GPU and, in train mode, their
        optimizers and the tensorboard writer."""
        cfg = self.config

        # Core modules, all placed on the GPU.
        self.linear_compress = nn.Linear(cfg.input_size,
                                         cfg.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=cfg.hidden_size,
                                     hidden_size=cfg.hidden_size,
                                     num_layers=cfg.num_layers).cuda()
        self.discriminator = Discriminator(input_size=cfg.hidden_size,
                                           hidden_size=cfg.hidden_size,
                                           num_layers=cfg.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if cfg.mode == 'train':
            # One optimizer per sub-objective; each also updates the shared
            # linear compressor.
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters())
                + list(self.summarizer.vae.e_lstm.parameters())
                + list(self.linear_compress.parameters()),
                lr=cfg.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters())
                + list(self.linear_compress.parameters()),
                lr=cfg.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters())
                + list(self.linear_compress.parameters()),
                lr=cfg.discriminator_lr)

            self.writer = TensorboardWriter(str(cfg.log_dir))
コード例 #3
0
    def build(self, cuda=True):
        """Lazily create and initialize the model, move it to the GPU,
        restore a checkpoint if configured, and prepare training state."""
        if self.model is None:
            model_cls = getattr(models, self.config.model)
            self.model = model_cls(self.config)

            # Fresh training run: orthogonal recurrent weights and a 2.0
            # bias on the middle third of bias_hh.
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    if 'bias_hh' in name:
                        print('\t' + name)
                        gate_dim = int(param.size(0) / 3)
                        param.data[gate_dim:2 * gate_dim].fill_(2.0)

        if cuda and torch.cuda.is_available():
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            trainable = filter(lambda p: p.requires_grad,
                               self.model.parameters())
            self.optimizer = self.config.optimizer(
                trainable, lr=self.config.learning_rate)
コード例 #4
0
ファイル: solver.py プロジェクト: herok97/2021-1_Lecture
    def build(self):
        """Report CUDA device info, build the GAN modules on the GPU, and in
        train mode create the optimizers and tensorboard writer."""
        # cuDNN disabled by the original author.
        torch.backends.cudnn.enabled = False

        # GPU availability / device details.
        USE_CUDA = torch.cuda.is_available()
        print(USE_CUDA)
        device = torch.device('cuda:0' if USE_CUDA else 'cpu')
        print('학습을 진행하는 기기:', device)
        print('cuda index:', torch.cuda.current_device())
        print('gpu 개수:', torch.cuda.device_count())
        print('graphic name:', torch.cuda.get_device_name())
        # setting device on GPU if available, else CPU
        print('Using device:', device)

        cfg = self.config

        # Build the modules, all placed on the GPU.
        self.linear_compress = nn.Linear(cfg.input_size,
                                         cfg.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=cfg.hidden_size,
                                     hidden_size=cfg.hidden_size,
                                     num_layers=cfg.num_layers).cuda()
        self.discriminator = Discriminator(input_size=cfg.hidden_size,
                                           hidden_size=cfg.hidden_size,
                                           num_layers=cfg.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if cfg.mode == 'train':
            # One optimizer per sub-objective; each also updates the shared
            # linear compressor.
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters())
                + list(self.summarizer.vae.e_lstm.parameters())
                + list(self.linear_compress.parameters()),
                lr=cfg.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters())
                + list(self.linear_compress.parameters()),
                lr=cfg.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters())
                + list(self.linear_compress.parameters()),
                lr=cfg.discriminator_lr)
            print(self.model)
            self.model.train()

            self.writer = TensorboardWriter(cfg.log_dir)
コード例 #5
0
ファイル: solver.py プロジェクト: NoSyu/CDMM-B
    def build(self, cuda=True):
        """Create and initialize the model, print its parameters, and set up
        the optimizer (AdamW with no-decay groups when none is configured)."""
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # Fresh run: orthogonal recurrent weights, 2.0 on the middle
            # third of bias_hh.
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    if 'bias_hh' in name:
                        print('\t' + name)
                        third = int(param.size(0) / 3)
                        param.data[third:2 * third].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        # Parameter overview.
        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            if self.config.optimizer is None:
                # Default to AdamW; biases and LayerNorm weights are
                # excluded from weight decay.
                no_decay = ['bias', 'LayerNorm.weight']
                decayed, undecayed = [], []
                for n, p in self.model.named_parameters():
                    if any(nd in n for nd in no_decay):
                        undecayed.append(p)
                    else:
                        decayed.append(p)
                grouped = [
                    {'params': decayed, 'weight_decay': 0.01},
                    {'params': undecayed, 'weight_decay': 0.0},
                ]
                self.optimizer = AdamW(grouped, lr=self.config.learning_rate)
            else:
                self.optimizer = self.config.optimizer(
                    filter(lambda p: p.requires_grad, self.model.parameters()),
                    lr=self.config.learning_rate)
コード例 #6
0
    def __init__(
        self,
        caption_model: str,
        epochs: int,
        device: torch.device,
        word_map: Dict[str, int],
        rev_word_map: Dict[int, str],
        start_epoch: int,
        epochs_since_improvement: int,
        best_bleu4: float,
        train_loader: DataLoader,
        val_loader: DataLoader,
        encoder: nn.Module,
        decoder: nn.Module,
        encoder_optimizer: optim.Optimizer,
        decoder_optimizer: optim.Optimizer,
        loss_function: nn.Module,
        grad_clip: float,
        tau: float,
        fine_tune_encoder: bool,
        tensorboard: bool = False,
        log_dir: Optional[str] = None
    ) -> None:
        """Store the training configuration and wire up the writer."""
        self.device = device  # GPU / CPU

        # Model identity and vocabulary.
        self.caption_model, self.epochs = caption_model, epochs
        self.word_map, self.rev_word_map = word_map, rev_word_map

        # Resume / progress bookkeeping.
        self.start_epoch = start_epoch
        self.epochs_since_improvement = epochs_since_improvement
        self.best_bleu4 = best_bleu4

        # Data, networks, optimizers, and loss.
        self.train_loader, self.val_loader = train_loader, val_loader
        self.encoder, self.decoder = encoder, decoder
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
        self.loss_function = loss_function

        # Training knobs.
        self.tau, self.grad_clip = tau, grad_clip
        self.fine_tune_encoder = fine_tune_encoder

        # Print training/validation stats every `print_freq` batches.
        self.print_freq = 100
        # Visualization writer instance.
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)
コード例 #7
0
    def build(self, cuda=True):
        """Lazily construct the model, move it to the GPU, load any
        checkpoint, and create the writer/optimizer when training."""
        if self.model is None:
            model_cls = getattr(models, self.config.model)
            self.model = model_cls(self.config)

        if cuda and torch.cuda.is_available():
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            trainable_params = (p for p in self.model.parameters()
                                if p.requires_grad)
            self.optimizer = self.config.optimizer(
                trainable_params, lr=self.config.learning_rate)
コード例 #8
0
    def build(self):
        """Build the attentive encoder and, in train mode, its optimizer and
        tensorboard writer."""
        cfg = self.config

        # Encoder variant currently in use; simple_encoder_LSTM and
        # attentive_encoder_decoder_LSTM were the alternatives tried.
        self.summarizer = attentive_encoder_LSTM(
            input_size=cfg.input_size,
            hidden_size=cfg.hidden_size,
            num_layers=cfg.num_layers).cuda()

        if cfg.mode == 'train':
            self.optimizer = optim.Adam(self.summarizer.parameters(),
                                        lr=cfg.lr)
            self.summarizer.train()
            # Tensorboard
            self.writer = TensorboardWriter(cfg.log_dir)
コード例 #9
0
    def build(self, cuda=True):
        """Create the model with orthogonal/GRU-bias initialization, move it
        to the GPU, print its parameters, and prepare training state."""
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # Fresh run: orthogonal init for hidden weights, input-gate
            # bias for GRUs.
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    # bias_hh concatenates the reset, input, and new gate
                    # biases; only the middle (input-gate) slice gets 2.0.
                    if 'bias_hh' in name:
                        print('\t' + name)
                        gate = int(param.size(0) / 3)
                        param.data[gate:2 * gate].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model = self.model.cuda()
        # Multi-GPU DataParallel wrapping was disabled in the original and
        # remains disabled here.

        # Overview Parameters
        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            trainable = filter(lambda p: p.requires_grad,
                               self.model.parameters())
            self.optimizer = self.config.optimizer(
                trainable, lr=self.config.learning_rate)
コード例 #10
0
    def build(self):
        """Build the GAN modules and optimizers and put the model in train
        mode.

        Fix: removed a leftover ``import ipdb; ipdb.set_trace()`` debugger
        breakpoint that halted every run just before the tensorboard writer
        was created (and added a hard dependency on the ipdb package).
        """
        # Build Modules — all placed on the GPU.
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers — one per sub-objective; each also updates
            # the shared linear compressor.
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.model.train()

            # Tensorboard
            self.writer = TensorboardWriter(self.config.log_dir)
コード例 #11
0
    def __init__(self,
                 num_epochs: int,
                 start_epoch: int,
                 train_loader: DataLoader,
                 model: nn.Module,
                 model_name: str,
                 loss_function: nn.Module,
                 optimizer,
                 lr_decay: float,
                 dataset_name: str,
                 word_map: Dict[str, int],
                 grad_clip: Optional[float] = None,
                 print_freq: int = 100,
                 checkpoint_path: Optional[str] = None,
                 checkpoint_basename: str = 'checkpoint',
                 tensorboard: bool = False,
                 log_dir: Optional[str] = None) -> None:
        """Store the training configuration and wire up the writer.

        Fix: the original signature read ``grad_clip=Optional[None]``, which
        makes the *default value* the typing construct ``Optional[None]``
        (i.e. the ``NoneType`` class — a truthy object), not an annotation.
        Any ``if self.grad_clip:`` check downstream would therefore treat
        clipping as enabled by default. The intended form is an
        ``Optional[float]`` annotation with a ``None`` default.
        """
        self.num_epochs = num_epochs
        self.start_epoch = start_epoch
        self.train_loader = train_loader

        self.model = model
        self.model_name = model_name
        self.loss_function = loss_function
        self.optimizer = optimizer
        self.lr_decay = lr_decay

        self.dataset_name = dataset_name
        self.word_map = word_map
        self.print_freq = print_freq  # print stats every `print_freq` batches
        self.grad_clip = grad_clip  # None disables gradient clipping

        self.checkpoint_path = checkpoint_path
        self.checkpoint_basename = checkpoint_basename

        # setup visualization writer instance
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)
コード例 #12
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        """Prepare devices, model, monitoring state, and logging."""
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # Resolve devices and move the model onto the main GPU; wrap in
        # DataParallel when several device ids were configured.
        self.device, self.device_ids = self._prepare_device(
            config['num_gpu'], config['main_device_id'], config['device_id'])
        self.model = model.cuda(self.device)
        if len(self.device_ids) > 1:
            self.model = torch.nn.DataParallel(model,
                                               device_ids=self.device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')
        self.add_graph = cfg_trainer.get('add_graph', False)

        # Model-performance monitoring / best-checkpoint bookkeeping.
        if self.monitor == 'off':
            self.mnt_mode, self.mnt_best = 'off', 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1
        self.best_epoch = self.start_epoch
        self.checkpoint_dir = config.save_dir
        self.logger_dir = config.log_dir
        # Experiment names end with the NER type, e.g. "..._PER".
        self.ner_type = config.config['experiment_name'].split('_')[-1]

        # Visualization writer instance.
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
コード例 #13
0
ファイル: train.py プロジェクト: NoAchache/TextBoxGan
class Trainer(object):
    """ Train the model. The different configs can be tuned in config/config. """

    def __init__(self):
        """Wire up data loaders, models, optimizers, and checkpointing from
        the global ``cfg`` module."""

        self.batch_size = cfg.batch_size
        self.strategy = cfg.strategy
        self.max_steps = cfg.max_steps
        self.summary_steps_frequency = cfg.summary_steps_frequency
        self.image_summary_step_frequency = cfg.image_summary_step_frequency
        self.save_step_frequency = cfg.save_step_frequency
        self.log_dir = cfg.log_dir

        self.validation_step_frequency = cfg.validation_step_frequency
        self.tensorboard_writer = TensorboardWriter(self.log_dir)
        # set optimizer params (rescaled for lazy regularization, see
        # update_optimizer_params)
        self.g_opt = self.update_optimizer_params(cfg.g_opt)
        self.d_opt = self.update_optimizer_params(cfg.d_opt)
        # Path-length penalty running mean; ON_READ / ONLY_FIRST_REPLICA so
        # per-replica updates are aggregated once under the strategy.
        self.pl_mean = tf.Variable(
            initial_value=0.0,
            name="pl_mean",
            trainable=False,
            synchronization=tf.VariableSynchronization.ON_READ,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )
        self.training_data_loader = TrainingDataLoader()
        self.validation_data_loader = ValidationDataLoader("validation_corpus.txt")
        self.model_loader = ModelLoader()
        # create model: model and optimizer must be created under `strategy.scope`
        # g_clone is kept as a moving average of the generator (see train()).
        (
            self.discriminator,
            self.generator,
            self.g_clone,
        ) = self.model_loader.initiate_models()

        # set optimizers — separate Adam instances so moment estimates are
        # not shared between the discriminator, generator, and OCR updates.
        self.d_optimizer = tf.keras.optimizers.Adam(
            self.d_opt["learning_rate"],
            beta_1=self.d_opt["beta1"],
            beta_2=self.d_opt["beta2"],
            epsilon=self.d_opt["epsilon"],
        )
        self.g_optimizer = tf.keras.optimizers.Adam(
            self.g_opt["learning_rate"],
            beta_1=self.g_opt["beta1"],
            beta_2=self.g_opt["beta2"],
            epsilon=self.g_opt["epsilon"],
        )
        # NOTE(review): the OCR optimizer reuses the generator's hyperparams.
        self.ocr_optimizer = tf.keras.optimizers.Adam(
            self.g_opt["learning_rate"],
            beta_1=self.g_opt["beta1"],
            beta_2=self.g_opt["beta2"],
            epsilon=self.g_opt["epsilon"],
        )
        self.ocr_loss_weight = cfg.ocr_loss_weight

        self.aster_ocr = AsterInferer()

        self.training_step = TrainingStep(
            self.generator,
            self.discriminator,
            self.aster_ocr,
            self.g_optimizer,
            self.ocr_optimizer,
            self.d_optimizer,
            self.g_opt["reg_interval"],
            self.d_opt["reg_interval"],
            self.pl_mean,
        )

        self.validation_step = ValidationStep(self.g_clone, self.aster_ocr)

        # Restore (or create) the checkpoint manager tracking every
        # stateful object above.
        self.manager = self.model_loader.load_checkpoint(
            ckpt_kwargs={
                "d_optimizer": self.d_optimizer,
                "g_optimizer": self.g_optimizer,
                "ocr_optimizer": self.ocr_optimizer,
                "discriminator": self.discriminator,
                "generator": self.generator,
                "g_clone": self.g_clone,
                "pl_mean": self.pl_mean,
            },
            model_description="Full model",
            expect_partial=False,
            ckpt_dir=cfg.ckpt_dir,
            max_to_keep=cfg.num_ckpts_to_keep,
        )

    @staticmethod
    def update_optimizer_params(params: dict):
        """
        Updates the optimizer configurations.

        Parameters
        ----------
        params: Configs of the optimizer

        Returns
        -------
        Updated configuration of the optimizer

        """
        params_copy = params.copy()
        # Rescales lr/betas by reg_interval / (reg_interval + 1) —
        # presumably the lazy-regularization correction, where the
        # regularized loss only runs every `reg_interval` steps (TODO
        # confirm against the training-step implementation).
        mb_ratio = params_copy["reg_interval"] / (params_copy["reg_interval"] + 1)
        params_copy["learning_rate"] = params_copy["learning_rate"] * mb_ratio
        params_copy["beta1"] = params_copy["beta1"] ** mb_ratio
        params_copy["beta2"] = params_copy["beta2"] ** mb_ratio
        return params_copy

    def train(self):
        """
        Main training loop.

        """
        train_dataset = self.training_data_loader.load_dataset(
            batch_size=self.batch_size
        )

        train_dataset = self.strategy.experimental_distribute_dataset(train_dataset)

        validation_dataset = self.validation_data_loader.load_dataset(
            batch_size=self.batch_size
        )
        validation_dataset = self.strategy.experimental_distribute_dataset(
            validation_dataset
        )

        # start actual training
        print("Start Training")

        # setup loss trackers

        train_losses = [
            "reg_g_loss",
            "g_loss",
            "pl_penalty",
            "ocr_loss",
            "reg_d_loss",
            "d_loss",
            "r1_penalty",
        ]

        # One tracker per configured print frequency; each decides
        # independently whether its losses are also logged to tensorboard.
        loss_trackers = [
            LossTracker(train_losses, print_step, log_losses)
            for print_step, log_losses in zip(
                self.summary_steps_frequency["print_steps"],
                self.summary_steps_frequency["log_losses"],
            )
        ]

        validation_tracker = LossTracker(["validation_ocr_loss"])
        # start training
        for real_images, ocr_image, input_words, ocr_labels in train_dataset:
            step = self.g_optimizer.iterations.numpy()

            # g train step
            # Lazy regularization: R1/path-length penalties only every
            # reg_interval steps.
            do_r1_reg = True if (step + 1) % self.d_opt["reg_interval"] == 0 else False
            do_pl_reg = True if (step + 1) % self.g_opt["reg_interval"] == 0 else False

            if (
                step > 5000
            ):  # Set the ocr_loss_weight (close) to 0 at the beginning of the training since it is too early
                # to have a text to read from
                ocr_loss_weight = self.ocr_loss_weight

            else:
                ocr_loss_weight = 1e-8

            (gen_losses, disc_losses, ocr_loss,) = self.training_step.dist_train_step(
                real_images,
                ocr_image,
                input_words,
                ocr_labels,
                do_r1_reg,
                do_pl_reg,
                ocr_loss_weight,
            )

            reg_g_loss, g_loss, pl_penalty = gen_losses
            reg_d_loss, d_loss, r1_penalty = disc_losses

            # update g_clone
            self.g_clone.set_as_moving_average_of(self.generator)

            # get current step (re-read: dist_train_step advanced the
            # optimizer's iteration counter)
            step = self.g_optimizer.iterations.numpy()

            losses_dict = {
                "reg_g_loss": reg_g_loss,
                "g_loss": g_loss,
                "pl_penalty": pl_penalty,
                "ocr_loss": ocr_loss,
                "reg_d_loss": reg_d_loss,
                "d_loss": d_loss,
                "r1_penalty": r1_penalty,
            }

            for loss_tracker in loss_trackers:
                loss_tracker.increment_losses(losses_dict)

            # save every self.save_step
            if step % self.save_step_frequency == 0:
                self.manager.save(checkpoint_number=step)

            # save every self.image_summary_step
            if step % self.image_summary_step_frequency == 0:
                self.tensorboard_writer.log_images(
                    input_words, self.g_clone, self.aster_ocr, step
                )

            # Full pass over the validation set every
            # validation_step_frequency steps.
            # NOTE(review): this inner loop rebinds `input_words` and
            # `ocr_labels` from the outer loop — harmless here since both
            # are reassigned at the next outer iteration.
            if step % self.validation_step_frequency == 0:
                for input_words, ocr_labels in validation_dataset:
                    ocr_loss = self.validation_step.dist_validation_step(
                        input_words, ocr_labels
                    )
                    validation_tracker.increment_losses(
                        {"validation_ocr_loss": ocr_loss}
                    )

                self.tensorboard_writer.log_scalars(validation_tracker.losses, step)
                validation_tracker.print_losses(step)
                validation_tracker.reinitialize_tracker()

            # print every self.print_steps
            for loss_tracker in loss_trackers:
                if step % loss_tracker.print_step == 0:
                    loss_tracker.print_losses(step)
                    if loss_tracker.log_losses:
                        self.tensorboard_writer.log_scalars(loss_tracker.losses, step)
                    loss_tracker.reinitialize_tracker()
            if step == self.max_steps:
                break

        # save last checkpoint
        step = self.g_optimizer.iterations.numpy()
        self.manager.save(checkpoint_number=step)
        return
コード例 #14
0
    def build(self):
        """Build the bidirectional SC-LSTM model, its loss, output dirs, and
        optimizer.

        Fix: the original called ``self.embedding.from_pretrained(...)`` —
        ``from_pretrained`` is a *classmethod* that returns a new
        ``nn.Embedding``, so calling it on the instance and discarding the
        result never loaded the pretrained vectors. The weights are now
        copied into the existing embedding in place (which also preserves
        ``padding_idx=0``).
        """
        # Build Modules
        # self.device = torch.device('cuda:0,1')
        self.embedding = nn.Embedding(self.config.vocab_size,
                                      self.config.wemb_size,
                                      padding_idx=0)

        # Load pretrained word vectors into the embedding table and keep
        # them trainable.
        weights_matrix = torch.FloatTensor(
            pickle.load(open(p.word_vec_pkl, 'rb')))
        with torch.no_grad():
            self.embedding.weight.copy_(weights_matrix)
        self.embedding.weight.requires_grad = True

        # Per-layer keyword-gate projections for the forward and backward
        # hidden states.
        self.w_hr_fw = nn.ModuleList(self.config.num_layers * [
            nn.Linear(
                self.config.hidden_size, self.config.kwd_size, bias=False)
        ])
        self.w_hr_bw = nn.ModuleList(self.config.num_layers * [
            nn.Linear(
                self.config.hidden_size, self.config.kwd_size, bias=False)
        ])

        # Word-embedding -> keyword projection.
        self.w_wr = nn.Linear(self.config.wemb_size,
                              self.config.kwd_size,
                              bias=False)
        # Output projections (forward uses Sequential so a LogSoftmax can be
        # re-enabled easily; it is currently commented out).
        self.w_ho_fw = nn.Sequential(
            nn.Linear(self.config.hidden_size * self.config.num_layers,
                      self.config.vocab_size),
            #             nn.LogSoftmax(dim=-1)
        )
        self.w_ho_bw = nn.Linear(
            self.config.hidden_size * self.config.num_layers,
            self.config.vocab_size)
        self.sc_rnn_fw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)

        self.sc_rnn_bw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)

        self.model = nn.ModuleList([
            self.w_hr_fw, self.w_hr_bw, self.w_wr, self.w_ho_fw, self.w_ho_bw,
            self.sc_rnn_fw, self.sc_rnn_bw
        ])

        # Per-token loss (reduction='none') so padding can be masked later.
        self.criterion = nn.CrossEntropyLoss(reduction='none')

        # Zero initial hidden/cell states, shared across batches.
        with torch.no_grad():
            self.hc_list_init = (Variable(torch.zeros(self.config.num_layers,
                                                      self.config.batch_size,
                                                      self.config.hidden_size),
                                          requires_grad=False),
                                 Variable(torch.zeros(self.config.num_layers,
                                                      self.config.batch_size,
                                                      self.config.hidden_size),
                                          requires_grad=False))

        #--- Init dirs for output ---
        self.current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        if self.config.mode == 'train':
            # Overview Parameters: xavier for matrices, zeros for vectors.
            print('Init Model Parameters')
            for name, param in self.model.named_parameters():
                print('\t' + name + '\t', list(param.size()))
                if param.data.ndimension() >= 2:
                    nn.init.xavier_uniform_(param.data)
                else:
                    nn.init.zeros_(param.data)

            # Tensorboard
            self.writer = TensorboardWriter(p.tb_dir + self.current_time)
            # Add emb-layer
            self.model.train()
            # create checkpoint dir
            self.cp_dir = p.check_point.format(
                p.dataname, self.current_time)  # checkpoint dir
            os.makedirs(self.cp_dir)

        #--- Setup output file ---
        self.out_file = open(
            p.out_result_dir.format(p.dataname, self.current_time), 'w')

        # Embedding is appended AFTER the init loop above so its pretrained
        # weights are not overwritten, but BEFORE the optimizer is built so
        # they are still trained.
        self.model.append(self.embedding)
        # Build Optimizers
        self.optimizer = optim.Adam(list(self.model.parameters()),
                                    lr=self.config.lr)
        print(self.model)
コード例 #15
0
class Solver(object):
    """Builds, trains and evaluates a bidirectional SC-LSTM text generator.

    Owns a forward and a backward SC-LSTM stack, trains both with a masked
    per-token cross-entropy plus an L1 penalty on the keyword-retention
    vector, and decodes at evaluation time with beam search over the
    forward model followed by a backward-model rerank.
    """

    def __init__(self,
                 config=None,
                 train_loader=None,
                 test_loader=None,
                 valid_loader=None):
        """Class that Builds, Trains and Evaluates SCLSTM model"""
        # NOTE(review): valid_loader is accepted but never stored or used.
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader
        # Must be set before any CUDA context is created for it to take effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.config.gpu
        # NOTE(review): these pickle file handles are opened but never closed.
        self.vocab = pickle.load(open(p.word_vocab_pkl, 'rb'))
        self.kvoc = pickle.load(open(p.kwd_pkl, 'rb'))
        self.i2w = {i: w for i, w in enumerate(self.vocab)}  # index to vocab
        self.i2k = {i: k for i, k in enumerate(self.kvoc)}  # index to keyword
        self.w2i = {w: i for i, w in self.i2w.items()}

    def build(self):
        """Construct embedding, FW/BW SC-LSTM stacks, projections, loss and
        optimizer; in train mode also initialize parameters, tensorboard and
        the checkpoint directory, and always open the result output file."""
        # Build Modules
        # self.device = torch.device('cuda:0,1')
        # Word embedding; index 0 is reserved for padding.
        self.embedding = nn.Embedding(self.config.vocab_size,
                                      self.config.wemb_size,
                                      padding_idx=0)

        if True:
            weights_matrix = torch.FloatTensor(
                pickle.load(open(p.word_vec_pkl, 'rb')))
            # NOTE(review): nn.Embedding.from_pretrained is a classmethod that
            # RETURNS a new module; its result is discarded here, so the
            # pretrained vectors are never copied into self.embedding.
            # Verify intent (an in-place load would be
            # self.embedding.weight.data.copy_(weights_matrix)).
            self.embedding.from_pretrained(weights_matrix, freeze=False)
            self.embedding.weight.requires_grad = True

        # Per-layer hidden->keyword projections for the keyword detector.
        # NOTE(review): `n * [module]` repeats the SAME Linear object, so all
        # num_layers entries share one weight matrix; if independent per-layer
        # projections were intended, a comprehension is needed — verify.
        self.w_hr_fw = nn.ModuleList(self.config.num_layers * [
            nn.Linear(
                self.config.hidden_size, self.config.kwd_size, bias=False)
        ])
        self.w_hr_bw = nn.ModuleList(self.config.num_layers * [
            nn.Linear(
                self.config.hidden_size, self.config.kwd_size, bias=False)
        ])

        # Word-embedding -> keyword projection, shared by FW and BW passes.
        self.w_wr = nn.Linear(self.config.wemb_size,
                              self.config.kwd_size,
                              bias=False)
        # Output projections: concatenated per-layer hidden states -> vocab
        # logits. LogSoftmax is applied later (criterion expects raw logits).
        self.w_ho_fw = nn.Sequential(
            nn.Linear(self.config.hidden_size * self.config.num_layers,
                      self.config.vocab_size),
            #             nn.LogSoftmax(dim=-1)
        )
        self.w_ho_bw = nn.Linear(
            self.config.hidden_size * self.config.num_layers,
            self.config.vocab_size)
        # Forward and backward semantically-conditioned LSTM stacks.
        self.sc_rnn_fw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)

        self.sc_rnn_bw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)

        # Register all trainable submodules so parameters()/train()/eval()
        # and state_dict() cover everything at once.
        self.model = nn.ModuleList([
            self.w_hr_fw, self.w_hr_bw, self.w_wr, self.w_ho_fw, self.w_ho_bw,
            self.sc_rnn_fw, self.sc_rnn_bw
        ])

        # reduction='none' keeps per-token losses so padding can be masked.
        self.criterion = nn.CrossEntropyLoss(reduction='none')

        # Zero initial (h, c) states, sized for a full batch.
        # NOTE(review): fixed batch_size here implies the last (possibly
        # smaller) batch must still have this size — confirm loader drops it.
        with torch.no_grad():
            self.hc_list_init = (Variable(torch.zeros(self.config.num_layers,
                                                      self.config.batch_size,
                                                      self.config.hidden_size),
                                          requires_grad=False),
                                 Variable(torch.zeros(self.config.num_layers,
                                                      self.config.batch_size,
                                                      self.config.hidden_size),
                                          requires_grad=False))

        #--- Init dirs for output ---
        self.current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        if self.config.mode == 'train':
            # Overview Parameters
            print('Init Model Parameters')
            # Xavier init for matrices, zeros for biases/vectors.
            for name, param in self.model.named_parameters():
                print('\t' + name + '\t', list(param.size()))
                if param.data.ndimension() >= 2:
                    nn.init.xavier_uniform_(param.data)
                else:
                    nn.init.zeros_(param.data)

            # Tensorboard
            self.writer = TensorboardWriter(p.tb_dir + self.current_time)
            # Add emb-layer
            self.model.train()
            # create dir
            #             self.res_dir = p.result_path.format(p.dataname, self.current_time) # result dir
            self.cp_dir = p.check_point.format(
                p.dataname, self.current_time)  # checkpoint dir
            #             os.makedirs(self.res_dir)
            os.makedirs(self.cp_dir)

        #--- Setup output file ---
        # Generated sentences are appended here during evaluate().
        self.out_file = open(
            p.out_result_dir.format(p.dataname, self.current_time), 'w')

        # Embedding is added after init so Xavier does not overwrite it.
        self.model.append(self.embedding)
        #         self.model.to(self.device)
        # Build Optimizers
        self.optimizer = optim.Adam(list(self.model.parameters()),
                                    lr=self.config.lr)
        print(self.model)

    def load_model(self, ep):
        """Restore model/optimizer state from checkpoint of epoch `ep`.

        Looks in cp_dir when training, else in config.resume_dir; silently
        prints a message and leaves state untouched if the file is missing.
        """
        _fname = (self.cp_dir if self.config.mode == 'train' else
                  self.config.resume_dir) + 'chk_point_{}.pth'.format(ep)
        if os.path.isfile(_fname):
            print("=> loading checkpoint '{}'".format(_fname))
            if self.config.load_cpu:
                checkpoint = torch.load(_fname,
                                        map_location=lambda storage, loc:
                                        storage)  # load into cpu-mode
            else:
                checkpoint = torch.load(_fname)  # gpu-mode
            self.start_epoch = checkpoint['epoch']
            # checkpoint['state_dict'].pop('1.s_lstm.out.0.bias',None) # remove bias in selector
            self.model.load_state_dict(checkpoint['state_dict'])
            # Optimizer state is stored as a one-element list (see train()).
            self.optimizer.load_state_dict(checkpoint['optimizer'][0])
        else:
            print("=> no checkpoint found at '{}'".format(_fname))

    def _zero_grads(self):
        """Clear accumulated gradients before a new batch."""
        self.optimizer.zero_grad()

    def save_checkpoint(self, state, filename):
        """Serialize a checkpoint dict to `filename` via torch.save."""
        torch.save(state, filename)

    def get_norm_grad(self, module, norm_type=2):
        """Return the global L2 norm of all existing gradients in `module`.

        NOTE(review): norm_type is accepted but ignored; the computation is
        hard-coded to the 2-norm.
        """
        total_norm = 0
        for name, param in module.named_parameters():
            if param.grad is not None:
                total_norm += torch.sum(torch.pow(param.grad.view(-1), 2))
        return torch.sqrt(total_norm).data

    def one_step_fw(self, w_t, y_t, hc_list, d_t, rnn_model, w_hr, w_ho):
        """Run one decoding step for either direction and return its losses.

        Args:
            w_t: embedded input words for this time step.
            y_t: target word ids for this step (id 0 = padding, masked out).
            hc_list: (h, c) state tuple across layers.
            d_t: current keyword-retention vector.
            rnn_model, w_hr, w_ho: direction-specific SC-LSTM, hidden->kwd
                projections and output projection.

        Returns:
            (llk_step, l1_step, pred, hc_list, d_t): masked mean CE loss,
            mean L1 norm of the updated d_t, vocab logits, new state, new d_t.
        """
        h_tm1, _ = hc_list
        #--- Keyword detector ---
        # r_t gates how much of each keyword is still to be expressed;
        # d_t decays multiplicatively as keywords get realized.
        res_hr = sum(
            [w_hr[l](h_tm1[l]) for l in range(self.config.num_layers)])
        r_t = torch.sigmoid(self.w_wr(w_t) + self.config.alpha * res_hr)
        d_t = r_t * d_t
        flat_h, hc_list = rnn_model(w_t, hc_list, d_t)

        with torch.no_grad():
            # Exclude padding targets from the loss.
            mask = Variable((y_t != 0).float(), requires_grad=False)
            assert not torch.isnan(mask).any()
        pred = w_ho(flat_h)
        llk_step = torch.mean(self.criterion(pred, y_t) * mask)
        l1_step = torch.mean(torch.sum(torch.abs(d_t), dim=-1))
        assert not torch.isnan(llk_step).any()
        assert not torch.isnan(l1_step).any()
        return llk_step, l1_step, pred, hc_list, d_t

    def train_epoch(self):
        """Train FW then BW direction on every batch for one epoch.

        Gradients of both directions are accumulated (two backward calls)
        before a single clipped optimizer step per batch.

        Returns:
            (loss_list, l1_list, fw_list, bw_list): per-batch numpy traces of
            the averaged CE loss, L1 penalty and per-direction CE losses.
        """
        loss_list = []
        l1_list = []
        fw_list, bw_list = [], []
        for batch_i, doc_features in enumerate(
                tqdm(self.train_loader,
                     desc='Batch',
                     dynamic_ncols=True,
                     ascii=True)):
            self._zero_grads()
            doc, kwd = doc_features
            with torch.no_grad():
                var_doc = Variable(doc, requires_grad=False)
                var_kwd = Variable(kwd, requires_grad=False)

            doc_emb = self.embedding(var_doc)  # get word-emb

            #--- Word generation ---
            step_loss = []
            step_l1 = []

            #--- FW Stage ---
            # Teacher-forced left-to-right pass: input token t, target t+1.
            hc_list = self.hc_list_init
            d_t = var_kwd
            for t in range(p.MAX_DOC_LEN - 1):
                w_t = doc_emb[:, t, :]
                y_t = var_doc[:, t + 1]
                #                 h_tm1, _ = hc_list

                #                 #--- Keyword detector ---
                #                 res_hr = sum([self.w_hr[l](h_tm1[l]) for l in range(self.config.num_layers)])
                #                 r_t = torch.sigmoid(self.w_wr(w_t) + self.config.alpha*res_hr)
                #                 d_t = r_t*d_t
                # #                 print hc_list[0].shape, w_t.shape, d_t.shape
                #                 flat_h, hc_list = self.sc_rnn(w_t, hc_list, d_t)

                #                 #--- Log LLK ---
                #                 with torch.no_grad():
                #                     mask = Variable((y_t!=0).float(), requires_grad=False)
                #                     assert not torch.isnan(mask).any()
                #                 pred = self.w_ho(flat_h)
                #                 llk_step = torch.mean(self.criterion(pred, y_t) * mask)
                #                 l1_step = torch.mean(torch.sum(torch.abs(d_t), dim=-1))

                #                 assert not torch.isnan(llk_step).any()
                #                 assert not torch.isnan(l1_step).any()
                llk_step, l1_step, pred, hc_list, d_t = self.one_step_fw(
                    w_t, y_t, hc_list, d_t, self.sc_rnn_fw, self.w_hr_fw,
                    self.w_ho_fw)
                p_pred, w_pred = torch.max(nn.LogSoftmax(dim=-1)(pred), dim=-1)
                #                 print [(self.i2w[i], v) for i, v in zip(w_pred.detach().cpu().numpy(), p_pred.detach().cpu().numpy())]

                step_loss.append(llk_step)
                step_l1.append(l1_step)

            fw_loss = sum(step_loss)
            fw_l1 = sum(step_l1) * self.config.eta
            batch_loss = fw_loss + fw_l1
            # retain_graph so the shared embedding graph survives for the
            # BW-stage backward below.
            batch_loss.backward(retain_graph=True)

            #--- BW Stage ---
            # Right-to-left pass: input token t, target t-1.
            torch.cuda.empty_cache()
            step_loss = []
            step_l1 = []
            hc_list = self.hc_list_init
            d_t = var_kwd
            for t in range(p.MAX_DOC_LEN - 1, 0, -1):
                w_t = doc_emb[:, t, :]
                y_t = var_doc[:, t - 1]
                llk_step, l1_step, pred, hc_list, d_t = self.one_step_fw(
                    w_t, y_t, hc_list, d_t, self.sc_rnn_bw, self.w_hr_bw,
                    self.w_ho_bw)
                step_loss.append(llk_step)
                step_l1.append(l1_step)

            bw_loss = sum(step_loss)
            bw_l1 = sum(step_l1) * self.config.eta

            #--- BW for learning ---
            #             _loss = (fw_loss + bw_loss)/2.
            #             _l1 = (fw_l1 + bw_l1)/2.
            batch_loss = bw_loss + bw_l1
            batch_loss.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.config.clip)
            self.optimizer.step()

            #--- tracking loss ---
            loss_list.append(0.5 * (fw_loss + bw_loss).cpu().data.numpy())
            l1_list.append(0.5 * (fw_l1 + bw_l1).cpu().data.numpy())
            fw_list.append(fw_loss.cpu().data.numpy())
            bw_list.append(bw_loss.cpu().data.numpy())

        return loss_list, l1_list, fw_list, bw_list

    def train(self):
        """Main training loop: run epochs, checkpoint/evaluate every
        `eval_rate` epochs, and log loss traces to tensorboard."""
        print('***Start training ...')
        for epoch_i in tqdm(range(self.config.n_epoch),
                            desc='Epoch',
                            dynamic_ncols=True,
                            ascii=True):
            loss_list, l1_list, fw_list, bw_list = self.train_epoch()
            # Save parameters at checkpoint
            if (epoch_i + 1) % self.config.eval_rate == 0:

                #--- Dump model ---
                if self.config.write_model:
                    # save model
                    self.save_checkpoint(
                        {
                            'epoch': epoch_i + 1,
                            'state_dict': self.model.state_dict(),
                            'total_loss': np.mean(loss_list),
                            'optimizer': [self.optimizer.state_dict()],
                        },
                        filename=self.cp_dir +
                        'chk_point_{}.pth'.format(epoch_i + 1))

                #--- Eval each step ---
                if self.config.is_eval:
                    self.evaluate(epoch_i + 1)

            print(
                '\n***Ep-{} | Total_loss: {} [FW/BW {}/{}] | D-L1: {} | NORM: {}'
                .format(epoch_i, np.mean(loss_list), np.mean(fw_list),
                        np.mean(bw_list), np.mean(l1_list),
                        self.get_norm_grad(self.model)))

            #             self.writer.update_parameters(self.model, epoch_i)
            self.writer.update_loss(np.mean(loss_list), epoch_i, 'total_loss')
            self.writer.update_loss(np.mean(l1_list), epoch_i, 'l1_reg')
            self.writer.update_loss(np.mean(fw_list), epoch_i, 'fw_loss')
            self.writer.update_loss(np.mean(bw_list), epoch_i, 'bw_loss')

    def gen_one_step(self, x, hc_list, d_t, rnn_model, w_hr, w_ho):
        """Decode one step at inference time.

        Args:
            x: list of current word ids (batch of 1 during beam search).
            hc_list: CPU-side (h, c) state, re-wrapped each call.
            d_t: current keyword-retention vector.
            rnn_model, w_hr, w_ho: direction-specific modules.

        Returns:
            (log_probs, hc_list, d_t): numpy log-probabilities over the
            vocabulary, the new CPU-side state, the updated d_t.
        """
        with torch.no_grad():
            var_x = Variable(torch.LongTensor(x), requires_grad=False)
            d_t = Variable(d_t, requires_grad=False)
            hc_list = self.to_gpu(hc_list)

        w_t = self.embedding(var_x)
        h_tm1, _ = hc_list
        # Same keyword-detector update as in one_step_fw.
        res_hr = sum(
            [w_hr[l](h_tm1[l]) for l in range(self.config.num_layers)])
        r_t = torch.sigmoid(self.w_wr(w_t) + self.config.alpha * res_hr)
        d_t = r_t * d_t
        flat_h, hc_list = rnn_model(w_t, hc_list, d_t)
        _prob = nn.LogSoftmax(dim=-1)(w_ho(flat_h))
        return _prob.detach().cpu().numpy().squeeze(), self.to_cpu(
            hc_list), d_t.detach().cpu()

    def get_top_index(self, _prob):
        """Return candidate next-word indices from log-probabilities.

        Samples beam_size indices without replacement when config.is_sample,
        otherwise returns all indices sorted by decreasing probability.
        """
        # [b, vocab]
        _prob = np.exp(_prob)
        if self.config.is_sample:
            top_indices = np.random.choice(self.config.vocab_size,
                                           self.config.beam_size,
                                           replace=False,
                                           p=_prob.reshape(-1))
        else:
            top_indices = np.argsort(-_prob)

        return top_indices

    def to_cpu(self, _list):
        """Detach each tensor in the tuple and move it to CPU."""
        return tuple([m.detach().cpu() for m in _list])

    def to_gpu(self, _list):
        """Wrap each tensor as a no-grad Variable.

        NOTE(review): despite the name, nothing is moved to the GPU here —
        there is no .cuda() call; verify against how the model is placed.
        """
        return tuple([Variable(m, requires_grad=False) for m in _list])

    def rerank(self, beams, d_t):
        """Average each beam's forward score with a backward-model score.

        The candidate sentence is reversed, rescored with the BW SC-LSTM,
        and the mean of FW and BW length-normalized log-likelihoods becomes
        the new beam score.
        """
        def add_bw_score(w_list, d_t):
            #             import pdb; pdb.set_trace()
            with torch.no_grad():
                hc_list = (torch.zeros(self.config.num_layers, 1,
                                       self.config.hidden_size),
                           torch.zeros(self.config.num_layers, 1,
                                       self.config.hidden_size))
            w_list = [self.w2i[w] for w in w_list[::-1]]
            llk = 0.
            for i, w in enumerate(w_list[:-1]):
                _prob, hc_list, d_t = self.gen_one_step([w], hc_list, d_t,
                                                        self.sc_rnn_bw,
                                                        self.w_hr_bw,
                                                        self.w_ho_bw)
                llk += _prob[w_list[i + 1]]
            return llk / (len(w_list) - 1)

        for i, b in enumerate(beams):
            #             import pdb; pdb.set_trace()
            beams[i] = tuple([0.5 *
                              (b[0] + add_bw_score(b[1], d_t))]) + tuple(b[1:])

        return beams

    def evaluate(self, epoch_i):
        """Beam-search generation over the test set using checkpoint
        `epoch_i`; reranked best sentences are printed and written to
        self.out_file.

        Beam tuples are (score, words, [last_word_id], hc_list, d_t);
        word id 1 appears to be used as BOS and id 2 as EOS — confirm
        against the vocabulary pickle.
        """
        #--- load model ---
        self.load_model(epoch_i)
        self.model.eval()
        for r_id, doc_features in enumerate(
                tqdm(self.test_loader,
                     desc='Test',
                     dynamic_ncols=True,
                     ascii=True)):
            _, d_t = doc_features
            try:
                # Skip samples with no keywords at all.
                if torch.sum(d_t) == 0:
                    continue
                #--- Gen 1st step ---
                with torch.no_grad():
                    hc_list = (torch.zeros(self.config.num_layers, 1,
                                           self.config.hidden_size),
                               torch.zeros(self.config.num_layers, 1,
                                           self.config.hidden_size))

                b = (0.0, [self.i2w[1]], [1], hc_list, d_t)
                _prob, hc_list, d_t = self.gen_one_step(
                    b[2], b[3], b[4], self.sc_rnn_fw, self.w_hr_fw,
                    self.w_ho_fw)
                top_indices = self.get_top_index(_prob)
                beam_candidates = []
                for i in range(self.config.beam_size):
                    wordix = top_indices[i]
                    beam_candidates.append(
                        (b[0] + _prob[wordix], b[1] + [self.i2w[wordix]],
                         [wordix], hc_list, d_t))

                #--- Gen the whole sentence ---
                beams = beam_candidates[:self.config.beam_size]
                for t in range(self.config.gen_size - 1):
                    beam_candidates = []
                    for b in beams:
                        _prob, hc_list, d_t = self.gen_one_step(
                            b[2], b[3], b[4], self.sc_rnn_fw, self.w_hr_fw,
                            self.w_ho_fw)
                        top_indices = self.get_top_index(_prob)

                        for i in range(self.config.beam_size):
                            #--- already EOS ---
                            if b[2] == [2]:
                                beam_candidates.append(b)
                                break
                            wordix = top_indices[i]
                            beam_candidates.append((b[0] + _prob[wordix],
                                                    b[1] + [self.i2w[wordix]],
                                                    [wordix], hc_list, d_t))

                    # Rank by length-normalized score.
                    beam_candidates.sort(key=lambda x: x[0] / (len(x[1]) - 1),
                                         reverse=True)  # decreasing order
                    beams = beam_candidates[:self.config.
                                            beam_size]  # truncate to get new beams

                #--- RERANK beams ---
                beams = self.rerank(beams, doc_features[1])
                beams.sort(key=lambda x: x[0], reverse=True)

                res = "[*]EP_{}_KW_[{}]_SENT_[{}]\n".format(
                    epoch_i, ' '.join([
                        self.i2k[int(j)] for j in torch.flatten(
                            torch.nonzero(doc_features[1][0])).numpy()
                    ]), ' '.join(beams[0][1]))
                print(res)
                self.out_file.write(res)
                self.out_file.flush()
            except Exception as e:
                # NOTE(review): this broad catch silently skips any failing
                # sample, which can mask real bugs during generation.
                print('Exception: ', str(e))
                pass


#         self.out_file.close()

        self.model.train()
コード例 #16
0
class Solver(object):
    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates AC-SUM-GAN model"""
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Instantiate every sub-network of AC-SUM-GAN on the GPU and, in
        train mode, one Adam optimizer per training stage plus the
        tensorboard writer."""
        cfg = self.config

        # Feature compressor and the two LSTM-based adversarial components.
        self.linear_compress = nn.Linear(cfg.input_size,
                                         cfg.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=cfg.hidden_size,
                                     hidden_size=cfg.hidden_size,
                                     num_layers=cfg.num_layers).cuda()
        self.discriminator = Discriminator(input_size=cfg.hidden_size,
                                           hidden_size=cfg.hidden_size,
                                           num_layers=cfg.num_layers).cuda()

        # Actor-critic pair operating on fragment-level states.
        self.actor = Actor(state_size=cfg.action_state_size,
                           action_size=cfg.action_state_size).cuda()
        self.critic = Critic(state_size=cfg.action_state_size,
                             action_size=cfg.action_state_size).cuda()

        # Register everything so train()/eval()/parameters() cover it all.
        self.model = nn.ModuleList([
            self.linear_compress, self.summarizer, self.discriminator,
            self.actor, self.critic
        ])

        if cfg.mode == 'train':
            # One optimizer per alternating training stage: VAE encoder,
            # VAE decoder, discriminator(+compressor), selector+actor,
            # and critic.
            self.e_optimizer = optim.Adam(
                self.summarizer.vae.e_lstm.parameters(), lr=cfg.lr)
            self.d_optimizer = optim.Adam(
                self.summarizer.vae.d_lstm.parameters(), lr=cfg.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=cfg.discriminator_lr)
            self.optimizerA_s = optim.Adam(
                list(self.actor.parameters()) +
                list(self.summarizer.s_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=cfg.lr)
            self.optimizerC = optim.Adam(self.critic.parameters(), lr=cfg.lr)

            self.writer = TensorboardWriter(str(cfg.log_dir))

    def reconstruction_loss(self, h_origin, h_sum):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer.

        Computed as the Euclidean distance between the two feature tensors.
        """
        return torch.dist(h_origin, h_sum, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) ) for a diagonal Gaussian posterior.

        Standard VAE closed form, summed over all latent dimensions.
        """
        kl_per_dim = log_variance.exp() + mu.pow(2) - log_variance - 1
        return 0.5 * torch.sum(kl_per_dim)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization"""

        return torch.abs(
            torch.mean(scores) - self.config.regularization_factor)

    # MSE criterion for the least-squares adversarial losses. Defined as a
    # class attribute, so a single nn.MSELoss instance is shared by every
    # Solver instance (MSELoss is stateless, so sharing is harmless).
    criterion = nn.MSELoss()

    def AC(self, original_features, seq_len, action_fragments):
        """ Function that makes the actor's actions, in the training steps where the actor and critic components are not trained.

        Args:
            original_features: compressed frame features fed to the scorer
                (assumed [seq_len, 1, hidden_size] — TODO confirm).
            seq_len: number of frames in the video.
            action_fragments: per-fragment [start, end] frame indices,
                shape [action_state_size, 2].

        Returns:
            (weighted_features, weighted_scores): frame features and frame
            scores re-weighted by the fragments the actor selected.
        """
        scores = self.summarizer.s_lstm(original_features)  # [seq_len, 1]

        # Mean score per fragment forms the actor's initial state.
        fragment_scores = np.zeros(
            self.config.action_state_size)  # [num_fragments, 1]
        for fragment in range(self.config.action_state_size):
            fragment_scores[fragment] = scores[action_fragments[
                fragment, 0]:action_fragments[fragment, 1] + 1].mean()
        state = fragment_scores

        previous_actions = [
        ]  # save all the actions (the selected fragments of each episode)
        # Unselected frames get this < 1 down-weighting factor.
        reduction_factor = (
            self.config.action_state_size -
            self.config.termination_point) / self.config.action_state_size
        action_scores = (torch.ones(seq_len) * reduction_factor).cuda()
        action_fragment_scores = (torch.ones(
            self.config.action_state_size)).cuda()

        counter = 0
        for ACstep in range(self.config.termination_point):

            state = torch.FloatTensor(state).cuda()
            # select an action
            dist = self.actor(state)
            action = dist.sample(
            )  # returns a scalar between 0-action_state_size

            # Only the first selection of a fragment counts; repeats are
            # ignored (the episode step is still consumed).
            if action not in previous_actions:
                previous_actions.append(action)
                action_factor = (self.config.termination_point - counter) / (
                    self.config.action_state_size - counter) + 1

                # Boost every frame of the chosen fragment (> 1 factor).
                action_scores[action_fragments[action,
                                               0]:action_fragments[action, 1] +
                              1] = action_factor
                # Zero the fragment in the state so it can't dominate again.
                action_fragment_scores[action] = 0

                counter = counter + 1

            # Next state: remaining fragment scores, back on CPU as numpy.
            next_state = state * action_fragment_scores
            next_state = next_state.cpu().detach().numpy()
            state = next_state

        weighted_scores = action_scores.unsqueeze(1) * scores
        weighted_features = weighted_scores.view(-1, 1, 1) * original_features

        return weighted_features, weighted_scores

    def train(self):

        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            self.model.train()
            recon_loss_init_history = []
            recon_loss_history = []
            sparsity_loss_history = []
            prior_loss_history = []
            g_loss_history = []
            e_loss_history = []
            d_loss_history = []
            c_original_loss_history = []
            c_summary_loss_history = []
            actor_loss_history = []
            critic_loss_history = []
            reward_history = []

            # Train in batches of as many videos as the batch_size
            num_batches = int(len(self.train_loader) / self.config.batch_size)
            iterator = iter(self.train_loader)
            for batch in range(num_batches):
                list_image_features = []
                list_action_fragments = []

                print(f'batch: {batch}')

                # ---- Train eLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training eLSTM...')
                self.e_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features, action_fragments = next(iterator)

                    action_fragments = action_fragments.squeeze(0)
                    # [batch_size, seq_len, input_size]
                    # [seq_len, input_size]
                    image_features = image_features.view(
                        -1, self.config.input_size)

                    list_image_features.append(image_features)
                    list_action_fragments.append(action_fragments)

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)

                    weighted_features, scores = self.AC(
                        original_features, seq_len, action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)

                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_sum, sum_prob = self.discriminator(generated_features)

                    if self.config.verbose:
                        tqdm.write(
                            f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                        )

                    reconstruction_loss = self.reconstruction_loss(
                        h_origin, h_sum)
                    prior_loss = self.prior_loss(h_mu, h_log_variance)

                    tqdm.write(
                        f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}'
                    )

                    e_loss = reconstruction_loss + prior_loss
                    e_loss = e_loss / self.config.batch_size
                    e_loss.backward()

                    prior_loss_history.append(prior_loss.data)
                    e_loss_history.append(e_loss.data)

                # Update e_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    self.summarizer.vae.e_lstm.parameters(), self.config.clip)
                self.e_optimizer.step()

                #---- Train dLSTM (decoder/generator) ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')
                self.d_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)

                    weighted_features, _ = self.AC(original_features, seq_len,
                                                   action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)

                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_sum, sum_prob = self.discriminator(generated_features)

                    tqdm.write(
                        f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                    )

                    reconstruction_loss = self.reconstruction_loss(
                        h_origin, h_sum)
                    g_loss = self.criterion(sum_prob, original_label)

                    orig_features = original_features.squeeze(
                        1)  # [seq_len, hidden_size]
                    gen_features = generated_features.squeeze(1)  #         >>
                    recon_losses = []
                    for frame_index in range(seq_len):
                        recon_losses.append(
                            self.reconstruction_loss(
                                orig_features[frame_index, :],
                                gen_features[frame_index, :]))
                    reconstruction_loss_init = torch.stack(recon_losses).mean()

                    if self.config.verbose:
                        tqdm.write(
                            f'recon loss {reconstruction_loss.item():.3f}, g loss: {g_loss.item():.3f}'
                        )

                    d_loss = reconstruction_loss + g_loss
                    d_loss = d_loss / self.config.batch_size
                    d_loss.backward()

                    recon_loss_init_history.append(
                        reconstruction_loss_init.data)
                    recon_loss_history.append(reconstruction_loss.data)
                    g_loss_history.append(g_loss.data)
                    d_loss_history.append(d_loss.data)

                # Update d_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    self.summarizer.vae.d_lstm.parameters(), self.config.clip)
                self.d_optimizer.step()

                #---- Train cLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training cLSTM...')
                self.c_optimizer.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # Train with original loss
                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    h_origin, original_prob = self.discriminator(
                        original_features)
                    c_original_loss = self.criterion(original_prob,
                                                     original_label)
                    c_original_loss = c_original_loss / self.config.batch_size
                    c_original_loss.backward()

                    # Train with summary loss
                    weighted_features, _ = self.AC(original_features, seq_len,
                                                   action_fragments)
                    h_mu, h_log_variance, generated_features = self.summarizer.vae(
                        weighted_features)
                    h_sum, sum_prob = self.discriminator(
                        generated_features.detach())
                    c_summary_loss = self.criterion(sum_prob, summary_label)
                    c_summary_loss = c_summary_loss / self.config.batch_size
                    c_summary_loss.backward()

                    tqdm.write(
                        f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                    )

                    c_original_loss_history.append(c_original_loss.data)
                    c_summary_loss_history.append(c_summary_loss.data)

                # Update c_lstm parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    list(self.discriminator.parameters()) +
                    list(self.linear_compress.parameters()), self.config.clip)
                self.c_optimizer.step()

                #---- Train sLSTM and actor-critic ----#
                if self.config.verbose:
                    tqdm.write('Training sLSTM, actor and critic...')
                self.optimizerA_s.zero_grad()
                self.optimizerC.zero_grad()
                for video in range(self.config.batch_size):
                    image_features = list_image_features[video]
                    action_fragments = list_action_fragments[video]

                    # [seq_len, input_size]
                    image_features_ = Variable(image_features).cuda()
                    seq_len = image_features_.shape[0]

                    # [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)
                    scores = self.summarizer.s_lstm(
                        original_features)  # [seq_len, 1]

                    fragment_scores = np.zeros(
                        self.config.action_state_size)  # [num_fragments, 1]
                    for fragment in range(self.config.action_state_size):
                        fragment_scores[fragment] = scores[action_fragments[
                            fragment,
                            0]:action_fragments[fragment, 1] + 1].mean()

                    state = fragment_scores  # [action_state_size, 1]

                    previous_actions = [
                    ]  # save all the actions (the selected fragments of each step)
                    reduction_factor = (self.config.action_state_size -
                                        self.config.termination_point
                                        ) / self.config.action_state_size
                    action_scores = (torch.ones(seq_len) *
                                     reduction_factor).cuda()
                    action_fragment_scores = (torch.ones(
                        self.config.action_state_size)).cuda()

                    log_probs = []
                    values = []
                    rewards = []
                    masks = []
                    entropy = 0

                    counter = 0
                    for ACstep in range(self.config.termination_point):
                        # select an action, get a value for the current state
                        state = torch.FloatTensor(
                            state).cuda()  # [action_state_size, 1]
                        dist, value = self.actor(state), self.critic(state)
                        action = dist.sample(
                        )  # returns a scalar between 0-action_state_size

                        if action in previous_actions:

                            reward = 0

                        else:

                            previous_actions.append(action)
                            action_factor = (
                                self.config.termination_point - counter
                            ) / (self.config.action_state_size - counter) + 1

                            action_scores[action_fragments[
                                action, 0]:action_fragments[action, 1] +
                                          1] = action_factor
                            action_fragment_scores[action] = 0

                            weighted_scores = action_scores.unsqueeze(
                                1) * scores
                            weighted_features = weighted_scores.view(
                                -1, 1, 1) * original_features

                            h_mu, h_log_variance, generated_features = self.summarizer.vae(
                                weighted_features)

                            h_origin, original_prob = self.discriminator(
                                original_features)
                            h_sum, sum_prob = self.discriminator(
                                generated_features)

                            tqdm.write(
                                f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                            )

                            rec_loss = self.reconstruction_loss(
                                h_origin, h_sum)
                            reward = 1 - rec_loss.item(
                            )  # the less the distance, the higher the reward
                            counter = counter + 1

                        next_state = state * action_fragment_scores
                        next_state = next_state.cpu().detach().numpy()

                        log_prob = dist.log_prob(action).unsqueeze(0)
                        entropy += dist.entropy().mean()

                        log_probs.append(log_prob)
                        values.append(value)
                        rewards.append(
                            torch.tensor([reward],
                                         dtype=torch.float,
                                         device=device))

                        if ACstep == self.config.termination_point - 1:
                            masks.append(
                                torch.tensor([0],
                                             dtype=torch.float,
                                             device=device))
                        else:
                            masks.append(
                                torch.tensor([1],
                                             dtype=torch.float,
                                             device=device))

                        state = next_state

                    next_state = torch.FloatTensor(next_state).to(device)
                    next_value = self.critic(next_state)
                    returns = compute_returns(next_value, rewards, masks)

                    log_probs = torch.cat(log_probs)
                    returns = torch.cat(returns).detach()
                    values = torch.cat(values)

                    advantage = returns - values

                    actor_loss = -((log_probs * advantage.detach()).mean() +
                                   (self.config.entropy_coef /
                                    self.config.termination_point) * entropy)
                    sparsity_loss = self.sparsity_loss(scores)
                    critic_loss = advantage.pow(2).mean()

                    actor_loss = actor_loss / self.config.batch_size
                    sparsity_loss = sparsity_loss / self.config.batch_size
                    critic_loss = critic_loss / self.config.batch_size
                    actor_loss.backward()
                    sparsity_loss.backward()
                    critic_loss.backward()

                    reward_mean = torch.mean(torch.stack(rewards))
                    reward_history.append(reward_mean)
                    actor_loss_history.append(actor_loss)
                    sparsity_loss_history.append(sparsity_loss)
                    critic_loss_history.append(critic_loss)

                    if self.config.verbose:
                        tqdm.write('Plotting...')

                    self.writer.update_loss(original_prob.data, step,
                                            'original_prob')
                    self.writer.update_loss(sum_prob.data, step, 'sum_prob')

                    step += 1

                # Update s_lstm, actor and critic parameters every 'batch_size' iterations
                torch.nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.linear_compress.parameters()) +
                    list(self.summarizer.s_lstm.parameters()) +
                    list(self.critic.parameters()), self.config.clip)
                self.optimizerA_s.step()
                self.optimizerC.step()

            recon_loss_init = torch.stack(recon_loss_init_history).mean()
            recon_loss = torch.stack(recon_loss_history).mean()
            prior_loss = torch.stack(prior_loss_history).mean()
            g_loss = torch.stack(g_loss_history).mean()
            e_loss = torch.stack(e_loss_history).mean()
            d_loss = torch.stack(d_loss_history).mean()
            c_original_loss = torch.stack(c_original_loss_history).mean()
            c_summary_loss = torch.stack(c_summary_loss_history).mean()
            sparsity_loss = torch.stack(sparsity_loss_history).mean()
            actor_loss = torch.stack(actor_loss_history).mean()
            critic_loss = torch.stack(critic_loss_history).mean()
            reward = torch.mean(torch.stack(reward_history))

            # Plot
            if self.config.verbose:
                tqdm.write('Plotting...')
            self.writer.update_loss(recon_loss_init, epoch_i,
                                    'recon_loss_init_epoch')
            self.writer.update_loss(recon_loss, epoch_i, 'recon_loss_epoch')
            self.writer.update_loss(prior_loss, epoch_i, 'prior_loss_epoch')
            self.writer.update_loss(g_loss, epoch_i, 'g_loss_epoch')
            self.writer.update_loss(e_loss, epoch_i, 'e_loss_epoch')
            self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            self.writer.update_loss(c_original_loss, epoch_i,
                                    'c_original_loss_epoch')
            self.writer.update_loss(c_summary_loss, epoch_i,
                                    'c_summary_loss_epoch')
            self.writer.update_loss(sparsity_loss, epoch_i,
                                    'sparsity_loss_epoch')
            self.writer.update_loss(actor_loss, epoch_i, 'actor_loss_epoch')
            self.writer.update_loss(critic_loss, epoch_i, 'critic_loss_epoch')
            self.writer.update_loss(reward, epoch_i, 'reward_epoch')

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'/epoch-{epoch_i}.pkl'
            if self.config.verbose:
                tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

    def evaluate(self, epoch_i):
        """Score every test video with the trained summarizer and dump the
        frame-level importance scores to a JSON file.

        :param epoch_i: index of the epoch that just finished training;
            used only to name the output score file.
        """
        # Inference mode: disables dropout / batch-norm statistic updates.
        self.model.eval()

        # video_name -> list of per-frame importance scores
        out_dict = {}

        for image_features, video_name, action_fragments in tqdm(
                self.test_loader, desc='Evaluate', ncols=80, leave=False):
            # Flatten to [seq_len, input_size] (test batch size is 1).
            image_features = image_features.view(-1, self.config.input_size)
            image_features_ = Variable(image_features).cuda()

            # Project CNN features into the model's hidden space:
            # [seq_len, 1, hidden_size]
            original_features = self.linear_compress(
                image_features_.detach()).unsqueeze(1)
            seq_len = original_features.shape[0]

            with torch.no_grad():

                # self.AC returns (weighted_features, scores); only the
                # per-frame scores are needed for evaluation.
                _, scores = self.AC(original_features, seq_len,
                                    action_fragments)

                scores = scores.squeeze(1)
                scores = scores.cpu().numpy().tolist()

                out_dict[video_name] = scores

            # NOTE(review): this write is inside the video loop, so the JSON
            # file is rewritten after every video — partial results survive
            # an interruption at the cost of repeated whole-file writes.
            score_save_path = self.config.score_dir.joinpath(
                f'{self.config.video_type}_{epoch_i}.json')
            with open(score_save_path, 'w') as f:
                if self.config.verbose:
                    tqdm.write(f'Saving score at {str(score_save_path)}.')
                json.dump(out_dict, f)
            score_save_path.chmod(0o777)
コード例 #17
0
class Learning(object):
    """Generic supervised training harness.

    Runs the epoch loop with gradient accumulation and clipping, tracks
    metrics, writes TensorBoard summaries, manages checkpointing,
    early stopping and LR scheduling.
    """

    def __init__(self,
            model,
            criterion,
            optimizer,
            scheduler,
            metric_ftns,
            device,
            num_epoch,
            grad_clipping,
            grad_accumulation_steps,
            early_stopping,
            validation_frequency,
            tensorboard,
            checkpoint_dir,
            resume_path):
        """
        :param model: torch.nn.Module to train.
        :param criterion: loss callable ``(output, target) -> scalar tensor``.
        :param optimizer: torch optimizer over the model's parameters.
        :param scheduler: LR scheduler; ``ReduceLROnPlateau`` is stepped with
            the score, any other scheduler is stepped unconditionally.
        :param metric_ftns: metric callables; each ``__name__`` becomes a
            column in the validation tracker.
        :param device: list of GPU ids to use (empty list -> CPU).
        :param num_epoch: total number of epochs to run.
        :param grad_clipping: max gradient norm for ``clip_grad_norm_``.
        :param grad_accumulation_steps: call ``optimizer.step()`` every N batches.
        :param early_stopping: stop after this many epochs without improvement.
        :param validation_frequency: run validation every N epochs.
        :param tensorboard: bool, enables TensorBoard logging.
        :param checkpoint_dir: directory for checkpoints and TB logs.
        :param resume_path: optional checkpoint path to resume from.
        """
        self.device, device_ids = self._prepare_device(device)
        # Assign and place the model BEFORE a possible resume:
        # _resume_checkpoint reads self.model, so it must already exist
        # (the original assigned self.model only afterwards, which raised
        # AttributeError whenever resume_path was given).  Using .to()
        # instead of an unconditional .cuda() keeps CPU-only hosts working.
        self.model = model.to(self.device)

        self.start_epoch = 1
        if resume_path is not None:
            self._resume_checkpoint(resume_path)
        if len(device_ids) > 1:
            # Wrap AFTER the (bare-model) checkpoint load, and keep the
            # wrapper: the original overwrote it with the unwrapped model
            # on the next line, silently disabling multi-GPU training.
            self.model = torch.nn.DataParallel(self.model)
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.num_epoch = num_epoch
        self.scheduler = scheduler
        self.grad_clipping = grad_clipping
        self.grad_accumulation_steps = grad_accumulation_steps
        self.early_stopping = early_stopping
        self.validation_frequency = validation_frequency
        self.checkpoint_dir = checkpoint_dir
        self.best_epoch = 1
        self.best_score = 0
        self.writer = TensorboardWriter(os.path.join(checkpoint_dir, 'tensorboard'), tensorboard)
        self.train_metrics = MetricTracker('loss', writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def train(self, train_dataloader):
        """Run the full training loop with checkpointing and early stopping."""
        score = 0
        for epoch in range(self.start_epoch, self.num_epoch + 1):
            print("{} epoch: \t start training....".format(epoch))
            start = time.time()
            train_result = self._train_epoch(epoch, train_dataloader)
            train_result.update({'time': time.time() - start})

            for key, value in train_result.items():
                print('    {:15s}: {}'.format(str(key), value))

            # Validation is currently disabled (see _valid_epoch); this
            # placeholder score grows monotonically so post_processing keeps
            # saving a "best" checkpoint every epoch.
            score += 0.001
            self.post_processing(score, epoch)
            if epoch - self.best_epoch > self.early_stopping:
                print('WARNING: EARLY STOPPING')
                break

    def _train_epoch(self, epoch, data_loader):
        """Train for one epoch and return the train-metric result dict."""
        self.model.train()
        self.optimizer.zero_grad()
        self.train_metrics.reset()
        # Log an input-image grid roughly sqrt(len) times per epoch;
        # max(..., 1) guards against modulo-by-zero on tiny loaders.
        log_interval = max(int(np.sqrt(len(data_loader))), 1)
        for idx, (data, target) in enumerate(data_loader):
            # .to(self.device) instead of .cuda() so CPU runs also work.
            data = data.to(self.device)
            target = [ann.to(self.device) for ann in target]
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.writer.set_step((epoch - 1) * len(data_loader) + idx)
            self.train_metrics.update('loss', loss.item())
            # Gradient accumulation: step/zero only every N batches.
            if (idx + 1) % self.grad_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clipping)
                self.optimizer.step()
                self.optimizer.zero_grad()
            if (idx + 1) % log_interval == 0:
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        return self.train_metrics.result()

    def _valid_epoch(self, epoch, data_loader):
        """Evaluate on the validation loader and return the metric dict."""
        self.valid_metrics.reset()
        self.model.eval()
        with torch.no_grad():
            for idx, (data, target) in enumerate(data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step((epoch - 1) * len(data_loader) + idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # Histogram every parameter for TensorBoard inspection.
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        return self.valid_metrics.result()

    def post_processing(self, score, epoch):
        """Track the best score, save a checkpoint and step the scheduler."""
        best = False
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            best = True
            print("best model: {} epoch - {:.5}".format(epoch, score))
        self._save_checkpoint(epoch=epoch, save_best=best)

        # ReduceLROnPlateau needs the monitored value; others step blindly.
        if self.scheduler.__class__.__name__ == 'ReduceLROnPlateau':
            self.scheduler.step(score)
        else:
            self.scheduler.step()

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints
        :param epoch: current epoch number
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.get_state_dict(self.model),
            'best_score': self.best_score
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint_epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        print("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            print("Saving current best: model_best.pth ...")

    @staticmethod
    def get_state_dict(model):
        """Return the state dict of the bare model, unwrapping DataParallel
        so checkpoints never carry the 'module.' key prefix."""
        # isinstance (not type ==) also covers DataParallel subclasses.
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        return state_dict

    def _resume_checkpoint(self, resume_path):
        """Load model weights and bookkeeping state from a checkpoint."""
        resume_path = str(resume_path)
        print("Loading checkpoint: {} ...".format(resume_path))
        # map_location keeps CPU-only resumption of GPU checkpoints working.
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        self.start_epoch = checkpoint['epoch'] + 1
        self.best_epoch = checkpoint['epoch']
        self.best_score = checkpoint['best_score']
        self.model.load_state_dict(checkpoint['state_dict'])

        print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))

    @staticmethod
    def _prepare_device(device):
        """Resolve the requested GPU-id list against the hardware.

        :param device: list of GPU ids ([] -> CPU).
        :return: (torch.device, list of usable device ids).
        """
        n_gpu_use = len(device)
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            print("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            print("Warning: The number of GPU\'s configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        list_ids = device
        device = torch.device('cuda:{}'.format(device[0]) if n_gpu_use > 0 else 'cpu')

        return device, list_ids
コード例 #18
0
class Solver(object):
    """Base solver: owns the model, data loaders and optimizer, and knows
    how to save/load checkpoints and push summaries to TensorBoard.
    Concrete training/evaluation loops are left to subclasses."""

    def __init__(self,
                 config,
                 train_data_loader,
                 eval_data_loader,
                 vocab,
                 is_train=True,
                 model=None):
        """Store configuration and collaborators; heavy setup is deferred
        to build()."""
        self.config = config
        self.vocab = vocab
        self.is_train = is_train
        self.model = model
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.epoch_i = 0
        # Populated later by build() / the training loop.
        self.writer = None
        self.optimizer = None
        self.epoch_loss = None
        self.validation_loss = None

    def build(self, cuda=True):
        """Create the model (if not injected), optionally initialize its
        recurrent weights, move it to GPU, restore a checkpoint, and set up
        the optimizer and TensorBoard writer for training."""
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # Fresh training run: orthogonal recurrent weights and a
            # positive forget-gate-style bias band.
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = param.size(0) // 3
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            trainable = (p for p in self.model.parameters() if p.requires_grad)
            self.optimizer = self.config.optimizer(
                trainable, lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Serialize model weights for the given epoch."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Restore weights from *checkpoint*; the leading digits of the
        filename are taken as the epoch number."""
        print(f'Load parameters from {checkpoint}')
        self.epoch_i = int(
            re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0))
        raw_state = torch.load(checkpoint)
        # Strip the 'module.' prefix that DataParallel adds to key names.
        cleaned = OrderedDict(
            (key[7:] if key.startswith("module.") else key, value)
            for key, value in raw_state.items())
        self.model.load_state_dict(cleaned)

    def write_summary(self, epoch_i):
        """Push every currently-available tracked loss to TensorBoard."""
        tracked = [
            ('epoch_loss', 'train_loss'),
            ('epoch_recon_loss', 'train_recon_loss'),
            ('epoch_kl_div', 'train_kl_div'),
            ('kl_mult', 'kl_mult'),
            ('epoch_bow_loss', 'bow_loss'),
            ('validation_loss', 'validation_loss'),
        ]
        for attr, name in tracked:
            value = getattr(self, attr, None)
            if value is not None:
                self.writer.update_loss(loss=value,
                                        step_i=epoch_i + 1,
                                        name=name)

    def train(self):
        raise NotImplementedError

    def evaluate(self):
        raise NotImplementedError

    def test(self):
        raise NotImplementedError

    def export_samples(self, beam_size=5):
        raise NotImplementedError
コード例 #19
0
class Solver(object):
    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates SUM-GAN-sl model.

        :param config: experiment configuration (sizes, lrs, dirs, mode...).
        :param train_loader: loader yielding per-video feature tensors.
        :param test_loader: loader yielding (features, video_name) pairs.
        """
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Instantiate the network modules and, in train mode, one Adam
        optimizer per adversarial player plus the TensorBoard writer."""

        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers — the shared linear_compress is updated by
            # every optimizer; each LSTM belongs to exactly one player.
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.writer = TensorboardWriter(str(self.config.log_dir))

    def reconstruction_loss(self, h_origin, h_sum):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer"""

        return torch.norm(h_origin - h_sum, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) )"""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization"""

        return torch.abs(
            torch.mean(scores) - self.config.regularization_factor)

    # Shared MSE criterion (LSGAN-style real/fake losses); class attribute
    # so all instances reuse the same stateless module.
    criterion = nn.MSELoss()

    def train(self):
        """Adversarial training loop: per batch, alternately update
        (sLSTM+eLSTM), dLSTM (generator) and cLSTM (discriminator), then
        log epoch averages, checkpoint and evaluate.

        NOTE(review): `original_label` and `summary_label` are not defined
        in this class — presumably module-level target tensors (1/0);
        verify where this example was taken from.
        """
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            s_e_loss_history = []
            d_loss_history = []
            c_original_loss_history = []
            c_summary_loss_history = []
            for batch_i, image_features in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=80,
                         leave=False)):

                self.model.train()

                # [batch_size=1, seq_len, 1024]
                # [seq_len, 1024]
                image_features = image_features.view(-1,
                                                     self.config.input_size)

                # [seq_len, 1024]
                image_features_ = Variable(image_features).cuda()

                #---- Train sLSTM, eLSTM ----#
                if self.config.verbose:
                    tqdm.write('\nTraining sLSTM and eLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)

                h_origin, original_prob = self.discriminator(original_features)
                h_sum, sum_prob = self.discriminator(generated_features)

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(h_origin, h_sum)
                prior_loss = self.prior_loss(h_mu, h_log_variance)
                sparsity_loss = self.sparsity_loss(scores)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}, sparsity loss: {sparsity_loss.item():.3f}'
                )

                s_e_loss = reconstruction_loss + prior_loss + sparsity_loss

                self.s_e_optimizer.zero_grad()
                s_e_loss.backward()
                # Gradient clipping: clip_grad_norm was deprecated and later
                # removed from PyTorch; use the in-place clip_grad_norm_.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.s_e_optimizer.step()

                s_e_loss_history.append(s_e_loss.data)

                #---- Train dLSTM (generator) ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)

                h_origin, original_prob = self.discriminator(original_features)
                h_sum, sum_prob = self.discriminator(generated_features)

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(h_origin, h_sum)
                # Generator tries to make the summary look "original".
                g_loss = self.criterion(sum_prob, original_label)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, g loss: {g_loss.item():.3f}'
                )

                d_loss = reconstruction_loss + g_loss

                self.d_optimizer.zero_grad()
                d_loss.backward()
                # Gradient clipping (in-place, current PyTorch API).
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.d_optimizer.step()

                d_loss_history.append(d_loss.data)

                #---- Train cLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training cLSTM...')

                self.c_optimizer.zero_grad()

                # Train with original loss
                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)
                h_origin, original_prob = self.discriminator(original_features)
                c_original_loss = self.criterion(original_prob, original_label)
                c_original_loss.backward()

                # Train with summary loss; detach() blocks gradients from
                # flowing back into the generator during this update.
                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                h_sum, sum_prob = self.discriminator(
                    generated_features.detach())
                c_summary_loss = self.criterion(sum_prob, summary_label)
                c_summary_loss.backward()

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, summary_p: {sum_prob.item():.3f}'
                )
                tqdm.write(f'gen loss: {g_loss.item():.3f}')

                # Gradient clipping (in-place, current PyTorch API).
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.c_optimizer.step()

                c_original_loss_history.append(c_original_loss.data)
                c_summary_loss_history.append(c_summary_loss.data)

                if self.config.verbose:
                    tqdm.write('Plotting...')

                self.writer.update_loss(reconstruction_loss.data, step,
                                        'recon_loss')
                self.writer.update_loss(prior_loss.data, step, 'prior_loss')
                self.writer.update_loss(sparsity_loss.data, step,
                                        'sparsity_loss')
                self.writer.update_loss(g_loss.data, step, 'gen_loss')

                self.writer.update_loss(original_prob.data, step,
                                        'original_prob')
                self.writer.update_loss(sum_prob.data, step, 'sum_prob')

                step += 1

            # Epoch-level averages of the per-batch losses.
            s_e_loss = torch.stack(s_e_loss_history).mean()
            d_loss = torch.stack(d_loss_history).mean()
            c_original_loss = torch.stack(c_original_loss_history).mean()
            c_summary_loss = torch.stack(c_summary_loss_history).mean()

            # Plot
            if self.config.verbose:
                tqdm.write('Plotting...')
            self.writer.update_loss(s_e_loss, epoch_i, 's_e_loss_epoch')
            self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            self.writer.update_loss(c_original_loss, step, 'c_original_loss')
            self.writer.update_loss(c_summary_loss, step, 'c_summary_loss')

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'/epoch-{epoch_i}.pkl'
            tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

    def evaluate(self, epoch_i):
        """Score the test videos with the selector LSTM and dump the
        per-frame importance scores to a JSON file named after the epoch."""

        self.model.eval()

        # video_name -> list of per-frame importance scores
        out_dict = {}

        for video_tensor, video_name in tqdm(self.test_loader,
                                             desc='Evaluate',
                                             ncols=80,
                                             leave=False):

            # [seq_len, batch=1, 1024]
            video_tensor = video_tensor.view(-1, self.config.input_size)
            video_feature = Variable(video_tensor).cuda()

            # [seq_len, 1, hidden_size]
            video_feature = self.linear_compress(
                video_feature.detach()).unsqueeze(1)

            # [seq_len]
            with torch.no_grad():
                scores = self.summarizer.s_lstm(video_feature).squeeze(1)
                scores = scores.cpu().numpy().tolist()

                out_dict[video_name] = scores

            # NOTE(review): rewritten after every video (inside the loop) —
            # partial results survive interruption at the cost of re-writes.
            score_save_path = self.config.score_dir.joinpath(
                f'{self.config.video_type}_{epoch_i}.json')
            with open(score_save_path, 'w') as f:
                tqdm.write(f'Saving score at {str(score_save_path)}.')
                json.dump(out_dict, f)
            score_save_path.chmod(0o777)

    def pretrain(self):
        pass
コード例 #20
0
class Solver(object):
    def __init__(self, config=None, train_loader=None, test_loader=None):
        """Class that Builds, Trains and Evaluates SUM-GAN model"""
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

    def build(self):
        """Instantiate the GPU modules and, in train mode, the three
        optimizers (sLSTM+eLSTM, dLSTM, discriminator) plus the
        tensorboard writer."""

        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers. Each sub-loss updates its own module
            # subset, and all three also update the shared input projection.
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.model.train()

            # Tensorboard
            self.writer = TensorboardWriter(self.config.log_dir)

    @staticmethod
    def freeze_model(module):
        """Disable gradient updates for every parameter of `module`."""
        for p in module.parameters():
            p.requires_grad = False

    def reconstruction_loss(self, h_origin, h_fake):
        """L2 loss between original-regenerated features at cLSTM's last hidden layer"""

        return torch.norm(h_origin - h_fake, p=2)

    def prior_loss(self, mu, log_variance):
        """KL( q(e|x) || N(0,1) )"""
        return 0.5 * torch.sum(-1 + log_variance.exp() + mu.pow(2) -
                               log_variance)

    def sparsity_loss(self, scores):
        """Summary-Length Regularization"""

        return torch.abs(torch.mean(scores) - self.config.summary_rate)

    def gan_loss(self, original_prob, fake_prob, uniform_prob):
        """Typical GAN loss + Classify uniformly scored features"""

        gan_loss = torch.mean(
            torch.log(original_prob) + torch.log(1 - fake_prob) +
            torch.log(1 - uniform_prob))  # Discriminate uniform score

        return gan_loss

    def train(self):
        """Alternating adversarial training.

        Per batch: (1) sLSTM+eLSTM minimize reconstruction + prior +
        sparsity, (2) dLSTM minimizes reconstruction + GAN loss, and
        (3) once `discriminator_slow_start` batches have passed, the
        discriminator maximizes the GAN loss. After every epoch the
        model is checkpointed and evaluated.
        """
        step = 0
        for epoch_i in trange(self.config.n_epochs, desc='Epoch', ncols=80):
            s_e_loss_history = []
            d_loss_history = []
            c_loss_history = []
            for batch_i, image_features in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=80,
                         leave=False)):

                # Skip extremely long sequences to bound memory usage.
                if image_features.size(1) > 10000:
                    continue

                # [batch_size=1, seq_len, input_size] -> [seq_len, input_size]
                image_features = image_features.view(-1,
                                                     self.config.input_size)

                image_features_ = Variable(image_features).cuda()

                #---- Train sLSTM, eLSTM ----#
                if self.config.verbose:
                    tqdm.write('\nTraining sLSTM and eLSTM...')

                # [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                _, _, _, uniform_features = self.summarizer(original_features,
                                                            uniform=True)

                h_origin, original_prob = self.discriminator(original_features)
                h_fake, fake_prob = self.discriminator(generated_features)
                h_uniform, uniform_prob = self.discriminator(uniform_features)

                # .item() replaces the `.data[0]` scalar access that was
                # removed for 0-dim tensors in PyTorch >= 0.4.
                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, fake_p: {fake_prob.item():.3f}, uniform_p: {uniform_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_fake)
                prior_loss = self.prior_loss(h_mu, h_log_variance)
                sparsity_loss = self.sparsity_loss(scores)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, prior loss: {prior_loss.item():.3f}, sparsity loss: {sparsity_loss.item():.3f}'
                )

                s_e_loss = reconstruction_loss + prior_loss + sparsity_loss

                self.s_e_optimizer.zero_grad()
                s_e_loss.backward()
                # Gradient clipping (clip_grad_norm_ supersedes the
                # deprecated clip_grad_norm).
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.s_e_optimizer.step()

                s_e_loss_history.append(s_e_loss.data)

                #---- Train dLSTM ----#
                if self.config.verbose:
                    tqdm.write('Training dLSTM...')

                # Fresh forward pass: [seq_len, 1, hidden_size]
                original_features = self.linear_compress(
                    image_features_.detach()).unsqueeze(1)

                scores, h_mu, h_log_variance, generated_features = self.summarizer(
                    original_features)
                _, _, _, uniform_features = self.summarizer(original_features,
                                                            uniform=True)

                h_origin, original_prob = self.discriminator(original_features)
                h_fake, fake_prob = self.discriminator(generated_features)
                h_uniform, uniform_prob = self.discriminator(uniform_features)

                tqdm.write(
                    f'original_p: {original_prob.item():.3f}, fake_p: {fake_prob.item():.3f}, uniform_p: {uniform_prob.item():.3f}'
                )

                reconstruction_loss = self.reconstruction_loss(
                    h_origin, h_fake)
                gan_loss = self.gan_loss(original_prob, fake_prob,
                                         uniform_prob)

                tqdm.write(
                    f'recon loss {reconstruction_loss.item():.3f}, gan loss: {gan_loss.item():.3f}'
                )

                d_loss = reconstruction_loss + gan_loss

                self.d_optimizer.zero_grad()
                d_loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.d_optimizer.step()

                d_loss_history.append(d_loss.data)

                #---- Train cLSTM (discriminator) ----#
                if batch_i > self.config.discriminator_slow_start:
                    if self.config.verbose:
                        tqdm.write('Training cLSTM...')
                    # Fresh forward pass: [seq_len, 1, hidden_size]
                    original_features = self.linear_compress(
                        image_features_.detach()).unsqueeze(1)

                    scores, h_mu, h_log_variance, generated_features = self.summarizer(
                        original_features)
                    _, _, _, uniform_features = self.summarizer(
                        original_features, uniform=True)

                    h_origin, original_prob = self.discriminator(
                        original_features)
                    h_fake, fake_prob = self.discriminator(generated_features)
                    h_uniform, uniform_prob = self.discriminator(
                        uniform_features)
                    tqdm.write(
                        f'original_p: {original_prob.item():.3f}, fake_p: {fake_prob.item():.3f}, uniform_p: {uniform_prob.item():.3f}'
                    )

                    # Maximization: ascend the GAN objective.
                    c_loss = -1 * self.gan_loss(original_prob, fake_prob,
                                                uniform_prob)

                    # BUG FIX: previously printed the stale `gan_loss` left
                    # over from the dLSTM step; report this step's loss.
                    tqdm.write(f'c loss: {c_loss.item():.3f}')

                    self.c_optimizer.zero_grad()
                    c_loss.backward()
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.config.clip)
                    self.c_optimizer.step()

                    c_loss_history.append(c_loss.data)

                if self.config.verbose:
                    tqdm.write('Plotting...')

                self.writer.update_loss(reconstruction_loss.data, step,
                                        'recon_loss')
                self.writer.update_loss(prior_loss.data, step, 'prior_loss')
                self.writer.update_loss(sparsity_loss.data, step,
                                        'sparsity_loss')
                self.writer.update_loss(gan_loss.data, step, 'gan_loss')

                self.writer.update_loss(original_prob.data, step,
                                        'original_prob')
                self.writer.update_loss(fake_prob.data, step, 'fake_prob')
                self.writer.update_loss(uniform_prob.data, step,
                                        'uniform_prob')

                step += 1

            # Plot epoch means. Guard against empty histories: every batch
            # may be skipped, and c_loss only exists once the discriminator
            # slow start has passed (torch.stack([]) would raise).
            if self.config.verbose:
                tqdm.write('Plotting...')
            if s_e_loss_history:
                s_e_loss = torch.stack(s_e_loss_history).mean()
                self.writer.update_loss(s_e_loss, epoch_i, 's_e_loss_epoch')
            if d_loss_history:
                d_loss = torch.stack(d_loss_history).mean()
                self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            if c_loss_history:
                c_loss = torch.stack(c_loss_history).mean()
                self.writer.update_loss(c_loss, epoch_i, 'c_loss_epoch')

            # Save parameters at checkpoint
            ckpt_path = str(self.config.save_dir) + f'_epoch-{epoch_i}.pkl'
            tqdm.write(f'Save parameters at {ckpt_path}')
            torch.save(self.model.state_dict(), ckpt_path)

            self.evaluate(epoch_i)

            # evaluate() switches to eval mode; switch back for next epoch.
            self.model.train()

    def evaluate(self, epoch_i):
        """Score test videos with the trained sLSTM and dump the scores to
        a per-epoch JSON file."""
        self.model.eval()

        out_dict = {}

        for video_tensor, video_name in tqdm(self.test_loader,
                                             desc='Evaluate',
                                             ncols=80,
                                             leave=False):

            # [seq_len, input_size]
            video_tensor = video_tensor.view(-1, self.config.input_size)
            # `volatile=True` was removed in PyTorch 0.4; torch.no_grad()
            # below is the modern equivalent (consistent with the sibling
            # Solver.evaluate in this file).
            video_feature = Variable(video_tensor).cuda()

            # [seq_len, 1, hidden_size]
            video_feature = self.linear_compress(
                video_feature.detach()).unsqueeze(1)

            # [seq_len]
            with torch.no_grad():
                scores = self.summarizer.s_lstm(video_feature).squeeze(1)
                scores = scores.cpu().numpy().tolist()

            out_dict[video_name] = scores

            score_save_path = self.config.score_dir.joinpath(
                f'{self.config.video_type}_{epoch_i}.json')
            with open(score_save_path, 'w') as f:
                tqdm.write(f'Saving score at {str(score_save_path)}.')
                json.dump(out_dict, f)
            score_save_path.chmod(0o777)

    def pretrain(self):
        """Optional pre-training hook; currently a no-op placeholder."""
        pass
コード例 #21
0
ファイル: train.py プロジェクト: NoAchache/TextBoxGan
    def __init__(self):
        """Wire up the full GAN + OCR training pipeline from the global
        ``cfg``: loop settings, data loaders, models, the three Adam
        optimizers, per-step helpers and the checkpoint manager.

        NOTE(review): construction order matters — models and optimizers
        must exist before TrainingStep and load_checkpoint receive them.
        """

        # Training-loop settings pulled straight from the global config.
        self.batch_size = cfg.batch_size
        self.strategy = cfg.strategy  # presumably a tf.distribute strategy — confirm
        self.max_steps = cfg.max_steps
        self.summary_steps_frequency = cfg.summary_steps_frequency
        self.image_summary_step_frequency = cfg.image_summary_step_frequency
        self.save_step_frequency = cfg.save_step_frequency
        self.log_dir = cfg.log_dir

        self.validation_step_frequency = cfg.validation_step_frequency
        self.tensorboard_writer = TensorboardWriter(self.log_dir)
        # set optimizer params
        self.g_opt = self.update_optimizer_params(cfg.g_opt)
        self.d_opt = self.update_optimizer_params(cfg.d_opt)
        # Running mean for the path-length regularizer; ON_READ/first-replica
        # settings keep it consistent under distributed training.
        self.pl_mean = tf.Variable(
            initial_value=0.0,
            name="pl_mean",
            trainable=False,
            synchronization=tf.VariableSynchronization.ON_READ,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )
        self.training_data_loader = TrainingDataLoader()
        self.validation_data_loader = ValidationDataLoader("validation_corpus.txt")
        self.model_loader = ModelLoader()
        # create model: model and optimizer must be created under `strategy.scope`
        (
            self.discriminator,
            self.generator,
            self.g_clone,  # presumably an EMA copy of the generator — confirm
        ) = self.model_loader.initiate_models()

        # set optimizers — one Adam each for D, G and the OCR branch
        # (OCR reuses the generator's hyper-parameters).
        self.d_optimizer = tf.keras.optimizers.Adam(
            self.d_opt["learning_rate"],
            beta_1=self.d_opt["beta1"],
            beta_2=self.d_opt["beta2"],
            epsilon=self.d_opt["epsilon"],
        )
        self.g_optimizer = tf.keras.optimizers.Adam(
            self.g_opt["learning_rate"],
            beta_1=self.g_opt["beta1"],
            beta_2=self.g_opt["beta2"],
            epsilon=self.g_opt["epsilon"],
        )
        self.ocr_optimizer = tf.keras.optimizers.Adam(
            self.g_opt["learning_rate"],
            beta_1=self.g_opt["beta1"],
            beta_2=self.g_opt["beta2"],
            epsilon=self.g_opt["epsilon"],
        )
        self.ocr_loss_weight = cfg.ocr_loss_weight

        # Frozen OCR model used to score generated text boxes.
        self.aster_ocr = AsterInferer()

        # Per-step train/validation helpers bound to the objects above.
        self.training_step = TrainingStep(
            self.generator,
            self.discriminator,
            self.aster_ocr,
            self.g_optimizer,
            self.ocr_optimizer,
            self.d_optimizer,
            self.g_opt["reg_interval"],
            self.d_opt["reg_interval"],
            self.pl_mean,
        )

        self.validation_step = ValidationStep(self.g_clone, self.aster_ocr)

        # Restore (or start) the checkpoint covering every stateful object.
        self.manager = self.model_loader.load_checkpoint(
            ckpt_kwargs={
                "d_optimizer": self.d_optimizer,
                "g_optimizer": self.g_optimizer,
                "ocr_optimizer": self.ocr_optimizer,
                "discriminator": self.discriminator,
                "generator": self.generator,
                "g_clone": self.g_clone,
                "pl_mean": self.pl_mean,
            },
            model_description="Full model",
            expect_partial=False,
            ckpt_dir=cfg.ckpt_dir,
            max_to_keep=cfg.num_ckpts_to_keep,
        )
コード例 #22
0
import os
import yaml
from loguru import logger
from time import gmtime, strftime
from utils import TensorboardWriter

# File logging: one timestamped log file per run under logs/.
if not os.path.isdir('logs'):
    os.mkdir('logs')
current_time = strftime("%Y-%m-%d_%H:%M:%S", gmtime())
logger.add(f'logs/train_{current_time}.log')

# problem on macOS
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

with open('config.yaml') as f:
    config = yaml.safe_load(f)

# Tensorboard is opt-in via the config flag.
_tb_enabled = 'USE_TENSORBOARD' in config and config['USE_TENSORBOARD']
tensorboard_writer = (TensorboardWriter(f'runs/{current_time}')
                      if _tb_enabled else None)

logger.info(f'config loaded: {config}')
コード例 #23
0
class Trainer:
    """
    Training pipeline

    Parameters
    ----------
    num_epochs : int
        We should train the model for __ epochs

    start_epoch : int
        We should start training the model from __th epoch

    train_loader : DataLoader
        DataLoader for training data

    model : nn.Module
        Model

    model_name : str
        Name of the model

    loss_function : nn.Module
        Loss function (cross entropy)

    optimizer : optim.Optimizer
        Optimizer (Adam)

    lr_decay : float
        A factor in interval (0, 1) to multiply the learning rate with

    dataset_name : str
        Name of the dataset

    word_map : Dict[str, int]
        Word2id map

    grad_clip : float, optional
        Gradient threshold in clip gradients (None disables clipping)

    print_freq : int
        Print training status every __ batches

    checkpoint_path : str, optional
        Path to the folder to save checkpoints

    checkpoint_basename : str, optional, default='checkpoint'
        Basename of the checkpoint

    tensorboard : bool, optional, default=False
        Enable tensorboard or not?

    log_dir : str, optional
        Path to the folder to save logs for tensorboard
    """
    def __init__(self,
                 num_epochs: int,
                 start_epoch: int,
                 train_loader: DataLoader,
                 model: nn.Module,
                 model_name: str,
                 loss_function: nn.Module,
                 optimizer,
                 lr_decay: float,
                 dataset_name: str,
                 word_map: Dict[str, int],
                 # BUG FIX: default was `Optional[None]` — the typing object
                 # itself, which is truthy, so the `is not None` check in
                 # train() passed for default callers and then crashed on an
                 # undefined name. `None` correctly disables clipping.
                 grad_clip: Optional[float] = None,
                 print_freq: int = 100,
                 checkpoint_path: Optional[str] = None,
                 checkpoint_basename: str = 'checkpoint',
                 tensorboard: bool = False,
                 log_dir: Optional[str] = None) -> None:
        self.num_epochs = num_epochs
        self.start_epoch = start_epoch
        self.train_loader = train_loader

        self.model = model
        self.model_name = model_name
        self.loss_function = loss_function
        self.optimizer = optimizer
        self.lr_decay = lr_decay

        self.dataset_name = dataset_name
        self.word_map = word_map
        self.print_freq = print_freq
        self.grad_clip = grad_clip

        self.checkpoint_path = checkpoint_path
        self.checkpoint_basename = checkpoint_basename

        # setup visualization writer instance
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)

    def train(self, epoch: int) -> None:
        """
        Train an epoch

        Parameters
        ----------
        epoch : int
            Current number of epoch
        """
        self.model.train()  # training mode enables dropout

        batch_time = AverageMeter(
        )  # forward prop. + back prop. time per batch
        data_time = AverageMeter()  # data loading time per batch
        losses = AverageMeter(tag='loss',
                              writer=self.writer)  # cross entropy loss
        accs = AverageMeter(tag='acc', writer=self.writer)  # accuracies

        start = time.time()

        # batches
        for i, batch in enumerate(self.train_loader):
            data_time.update(time.time() - start)

            if self.model_name in ['han']:
                documents, sentences_per_document, words_per_sentence, labels = batch

                documents = documents.to(
                    device)  # (batch_size, sentence_limit, word_limit)
                sentences_per_document = sentences_per_document.squeeze(1).to(
                    device)  # (batch_size)
                words_per_sentence = words_per_sentence.to(
                    device)  # (batch_size, sentence_limit)
                labels = labels.squeeze(1).to(device)  # (batch_size)

                # forward
                scores, _, _ = self.model(
                    documents, sentences_per_document, words_per_sentence
                )  # (n_documents, n_classes), (n_documents, max_doc_len_in_batch, max_sent_len_in_batch), (n_documents, max_doc_len_in_batch)

            else:
                sentences, words_per_sentence, labels = batch

                sentences = sentences.to(device)  # (batch_size, word_limit)
                words_per_sentence = words_per_sentence.squeeze(1).to(
                    device)  # (batch_size)
                labels = labels.squeeze(1).to(device)  # (batch_size)

                scores = self.model(
                    sentences, words_per_sentence)  # (batch_size, n_classes)

            # calc loss
            loss = self.loss_function(scores, labels)  # scalar

            # backward
            self.optimizer.zero_grad()
            loss.backward()

            # clip gradients
            # BUG FIX: referenced the undefined global `grad_clip`, raising
            # NameError whenever clipping was enabled; use the stored value.
            if self.grad_clip is not None:
                clip_gradient(self.optimizer, self.grad_clip)

            # update weights
            self.optimizer.step()

            # find accuracy
            _, predictions = scores.max(dim=1)  # (n_documents)
            correct_predictions = torch.eq(predictions, labels).sum().item()
            accuracy = correct_predictions / labels.size(0)

            # set step for tensorboard
            step = (epoch - 1) * self.len_epoch + i
            self.writer.set_step(step=step, mode='train')

            # keep track of metrics
            batch_time.update(time.time() - start)
            losses.update(loss.item(), labels.size(0))
            accs.update(accuracy, labels.size(0))

            start = time.time()

            # print training status
            if i % self.print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                        epoch,
                        i,
                        len(self.train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        acc=accs))

    def run_train(self):
        """Run the full training loop: train each epoch, decay the learning
        rate, and checkpoint after every epoch."""
        start = time.time()

        # epochs
        for epoch in range(self.start_epoch, self.num_epochs):
            # train an epoch
            self.train(epoch=epoch)

            # time per epoch
            epoch_time = time.time() - start
            print('Epoch: [{0}] finished, time consumed: {epoch_time:.3f}'.
                  format(epoch, epoch_time=epoch_time))

            # decay learning rate every epoch
            adjust_learning_rate(self.optimizer, self.lr_decay)

            # save checkpoint
            if self.checkpoint_path is not None:
                save_checkpoint(epoch=epoch,
                                model=self.model,
                                model_name=self.model_name,
                                optimizer=self.optimizer,
                                dataset_name=self.dataset_name,
                                word_map=self.word_map,
                                checkpoint_path=self.checkpoint_path,
                                checkpoint_basename=self.checkpoint_basename)

            start = time.time()
コード例 #24
0
ファイル: solver.py プロジェクト: NoSyu/CDMM-B
class Solver(object):
    """Train/eval harness: builds the configured model, restores
    checkpoints, and owns the optimizer plus tensorboard writer."""

    def __init__(self,
                 config,
                 train_data_loader,
                 eval_data_loader,
                 is_train=True,
                 model=None):
        self.config = config
        self.epoch_i = 0  # epoch to resume from; overwritten by load_model()
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.is_train = is_train
        self.model = model  # may be injected; otherwise created in build()

    @time_desc_decorator('Build Graph')
    def build(self, cuda=True):
        """Create the model (unless injected), initialize recurrent weights
        for a fresh run, move to GPU, restore a checkpoint if configured,
        and set up the optimizer + tensorboard writer in train mode."""
        if self.model is None:
            # Model class is looked up by name on the project's models module.
            self.model = getattr(models, self.config.model)(self.config)

            # Only initialize weights when training from scratch.
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        # Orthogonal init for recurrent (hidden-to-hidden)
                        # weight matrices.
                        nn.init.orthogonal_(param)

                    if 'bias_hh' in name:
                        print('\t' + name)
                        # Splits the recurrent bias into 3 gate sections and
                        # sets the middle one to 2.0 — the /3 assumes a
                        # 3-gate RNN such as GRU; TODO confirm model type.
                        dim = int(param.size(0) / 3)
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            if self.config.optimizer is None:
                # AdamW with decoupled weight decay: biases and LayerNorm
                # weights are exempt from decay, per common practice.
                no_decay = ['bias', 'LayerNorm.weight']
                optimizer_grouped_parameters = [{
                    'params': [
                        p for n, p in self.model.named_parameters()
                        if not any(nd in n for nd in no_decay)
                    ],
                    'weight_decay':
                    0.01
                }, {
                    'params': [
                        p for n, p in self.model.named_parameters()
                        if any(nd in n for nd in no_decay)
                    ],
                    'weight_decay':
                    0.0
                }]
                self.optimizer = AdamW(optimizer_grouped_parameters,
                                       lr=self.config.learning_rate)
            else:
                # Caller-supplied optimizer factory; only trainable params.
                self.optimizer = self.config.optimizer(
                    filter(lambda p: p.requires_grad, self.model.parameters()),
                    lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Serialize model weights to <save_path>/<epoch>.pkl."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Restore weights and recover the epoch index from the checkpoint
        filename (leading digits, matching save_model's naming)."""
        print(f'Load parameters from {checkpoint}')
        epoch = re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0)
        self.epoch_i = int(epoch)
        self.model.load_state_dict(torch.load(checkpoint))

    def write_summary(self, epoch_i):
        """Log train/validation accuracy for the epoch, if subclasses have
        recorded them (attributes may not exist yet — hence getattr)."""
        train_acc = getattr(self, 'train_acc', None)
        if train_acc is not None:
            self.writer.update_loss(loss=train_acc,
                                    step_i=epoch_i + 1,
                                    name='train_acc')

        validation_acc = getattr(self, 'validation_acc', None)
        if validation_acc is not None:
            self.writer.update_loss(loss=validation_acc,
                                    step_i=epoch_i + 1,
                                    name='validation_acc')

    def train(self):
        # Subclass responsibility.
        raise NotImplementedError

    def evaluate(self):
        # Subclass responsibility.
        raise NotImplementedError

    def test(self, is_print=True):
        # Subclass responsibility.
        raise NotImplementedError

    def _calc_accuracy(self, x, y):
        """Fraction of rows in logits `x` whose argmax equals label `y`."""
        max_vals, max_indices = torch.max(x, 1)
        train_acc = (max_indices
                     == y).sum().data.cpu().numpy() / max_indices.size()[0]

        return train_acc
コード例 #25
0
class Solver(object):
    """Train/eval scaffold: holds config, loaders and training state;
    concrete train/evaluate/test loops are supplied by subclasses."""

    def __init__(self,
                 config,
                 train_data_loader,
                 eval_data_loader,
                 is_train=True,
                 model=None):
        self.config = config
        self.epoch_i = 0
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.is_train = is_train
        self.model = model
        # Filled in by build() / the training loop.
        self.writer = None
        self.optimizer = None
        self.epoch_loss = None
        self.validation_loss = None
        self.true_scores = 0
        self.false_scores = 0
        self.eval_epoch_loss = 0

    def build(self, cuda=True):
        """Create the configured model if none was injected, move it to the
        GPU when available, restore a checkpoint, and — in train mode —
        set up the tensorboard writer and optimizer."""
        if self.model is None:
            model_cls = getattr(models, self.config.model)
            self.model = model_cls(self.config)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if not self.is_train:
            return

        self.writer = TensorboardWriter(self.config.logdir)
        trainable_params = (p for p in self.model.parameters()
                            if p.requires_grad)
        self.optimizer = self.config.optimizer(trainable_params,
                                               lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Serialize model weights to <save_path>/<epoch>.pkl."""
        destination = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {destination}')
        torch.save(self.model.state_dict(), destination)

    def load_model(self, checkpoint):
        """Restore weights and recover the epoch index from the leading
        digits of the checkpoint filename."""
        print(f'Load parameters from {checkpoint}')
        epoch_digits = re.match(r"[0-9]*",
                                os.path.basename(checkpoint)).group(0)
        self.epoch_i = int(epoch_digits)
        self.model.load_state_dict(torch.load(checkpoint))

    def write_summary(self, epoch_i):
        """Log the recorded train loss, then require subclasses to extend
        this with their remaining metrics."""
        recorded_loss = getattr(self, 'epoch_loss', None)
        if recorded_loss is not None:
            self.writer.update_loss(loss=recorded_loss,
                                    step_i=epoch_i + 1,
                                    name='train_loss')

        raise NotImplementedError

    def train(self):
        # Subclass responsibility.
        raise NotImplementedError

    def evaluate(self):
        # Subclass responsibility.
        raise NotImplementedError

    def test(self):
        # Subclass responsibility.
        raise NotImplementedError
コード例 #26
0
def _build_dataloader(config, split):
    """Resolve the '<Dataset>DataLoader' class from module globals and build
    a loader for *split* ('train' or 'val').

    Replaces the previous eval()-based lookup: globals() performs the same
    module-level name resolution without evaluating a config-derived string
    as code.
    """
    loader_cls = globals()["{}DataLoader".format(config.dataset)]
    return loader_cls(
        data_dir=os.path.join(config.data_dir, config.dataset),
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        split=split)


def main():
    """Train a VisionTransformer classifier end-to-end.

    Reads the training config, builds the model, dataloaders, criterion,
    optimizer and LR schedule, optionally warm-starts from a checkpoint,
    then runs the train/validate loop, saving a checkpoint every epoch and
    tagging the best top-1 accuracy.
    """
    config = get_train_config()

    # device
    device, device_ids = setup_device(config.n_gpu)

    # tensorboard
    writer = TensorboardWriter(config.summary_dir, config.tensorboard)

    # metric trackers (shared across epochs; reset internally per epoch)
    metric_names = ['loss', 'acc1', 'acc5']
    train_metrics = MetricTracker(*metric_names, writer=writer)
    valid_metrics = MetricTracker(*metric_names, writer=writer)

    # create model
    print("create model")
    model = VisionTransformer(image_size=(config.image_size,
                                          config.image_size),
                              patch_size=(config.patch_size,
                                          config.patch_size),
                              emb_dim=config.emb_dim,
                              mlp_dim=config.mlp_dim,
                              num_heads=config.num_heads,
                              num_layers=config.num_layers,
                              num_classes=config.num_classes,
                              attn_dropout_rate=config.attn_dropout_rate,
                              dropout_rate=config.dropout_rate)

    # load checkpoint; if the pretrained classifier head does not match the
    # target number of classes, drop it so the head is re-initialized
    if config.checkpoint_path:
        state_dict = load_checkpoint(config.checkpoint_path)
        if config.num_classes != state_dict['classifier.weight'].size(0):
            del state_dict['classifier.weight']
            del state_dict['classifier.bias']
            print("re-initialize fc layer")
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict)
        print("Load pretrained weights from {}".format(config.checkpoint_path))

    # send model to device
    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # create dataloaders
    print("create dataloaders")
    train_dataloader = _build_dataloader(config, 'train')
    valid_dataloader = _build_dataloader(config, 'val')

    # training criterion
    print("create criterion and optimizer")
    criterion = nn.CrossEntropyLoss()

    # create optimizer and learning rate scheduler
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=config.lr,
                                weight_decay=config.wd,
                                momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=config.lr,
        pct_start=config.warmup_steps / config.train_steps,
        total_steps=config.train_steps)

    # start training
    print("start training")
    best_acc = 0.0
    # NOTE(review): if train_steps < len(train_dataloader) this floors to 0
    # epochs and the loop body never runs — confirm config invariants.
    epochs = config.train_steps // len(train_dataloader)
    for epoch in range(1, epochs + 1):
        log = {'epoch': epoch}

        # train the model
        model.train()
        result = train_epoch(epoch, model, train_dataloader, criterion,
                             optimizer, lr_scheduler, train_metrics, device)
        log.update(result)

        # validate the model
        model.eval()
        result = valid_epoch(epoch, model, valid_dataloader, criterion,
                             valid_metrics, device)
        log.update(**{'val_' + k: v for k, v in result.items()})

        # track best top-1 accuracy so the checkpoint can be tagged as best
        best = False
        if log['val_acc1'] > best_acc:
            best_acc = log['val_acc1']
            best = True

        # save model every epoch
        save_model(config.checkpoint_dir, epoch, model, optimizer,
                   lr_scheduler, device_ids, best)

        # print logged informations to the screen
        for key, value in log.items():
            print('    {:15s}: {}'.format(str(key), value))
コード例 #27
0
class Solver(object):
    """Training/evaluation harness for conversational sequence models.

    Wraps a model from the `models` registry together with its data
    loaders, vocabulary, optimizer and Tensorboard writer, and provides
    the train / evaluate / test loops plus embedding-based generation
    metrics (average / extrema / greedy).
    """

    def __init__(self,
                 config,
                 train_data_loader,
                 eval_data_loader,
                 vocab,
                 is_train=True,
                 model=None):
        """Store configuration and collaborators.

        No heavy work happens here — call build() to construct the model,
        writer and optimizer.
        """
        self.config = config
        self.epoch_i = 0  # resume epoch; overwritten by load_model()
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.vocab = vocab
        self.is_train = is_train
        self.model = model

    @time_desc_decorator('Build Graph')
    def build(self, cuda=True):
        """Instantiate the model (if not injected), initialize parameters,
        move to GPU when available, restore a checkpoint if configured,
        and — in training mode — set up the writer and optimizer."""

        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # orthogonal initialization for hidden weights
            # input gate bias for GRUs
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initiailization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    # bias_hh is concatenation of reset, input, new gates
                    # only set the input gate bias to 2.0
                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = int(param.size(0) / 3)
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        # Overview Parameters
        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            # Only trainable parameters are handed to the optimizer.
            self.writer = TensorboardWriter(self.config.logdir)
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Save parameters to checkpoint named '<epoch>.pkl'."""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Load parameters from checkpoint and resume from the epoch
        encoded in the filename (leading digits of the basename)."""
        print(f'Load parameters from {checkpoint}')
        # NOTE(review): a basename without leading digits makes int('') raise.
        epoch = re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0)
        self.epoch_i = int(epoch)
        self.model.load_state_dict(torch.load(checkpoint))

    def write_summary(self, epoch_i):
        """Push every metric recorded on `self` this epoch to Tensorboard.

        Each metric is optional: it is only logged if the corresponding
        attribute has been set by train()/evaluate() (or a subclass).
        """
        epoch_loss = getattr(self, 'epoch_loss', None)
        if epoch_loss is not None:
            self.writer.update_loss(loss=epoch_loss,
                                    step_i=epoch_i + 1,
                                    name='train_loss')

        epoch_recon_loss = getattr(self, 'epoch_recon_loss', None)
        if epoch_recon_loss is not None:
            self.writer.update_loss(loss=epoch_recon_loss,
                                    step_i=epoch_i + 1,
                                    name='train_recon_loss')

        epoch_kl_div = getattr(self, 'epoch_kl_div', None)
        if epoch_kl_div is not None:
            self.writer.update_loss(loss=epoch_kl_div,
                                    step_i=epoch_i + 1,
                                    name='train_kl_div')

        kl_mult = getattr(self, 'kl_mult', None)
        if kl_mult is not None:
            self.writer.update_loss(loss=kl_mult,
                                    step_i=epoch_i + 1,
                                    name='kl_mult')

        epoch_bow_loss = getattr(self, 'epoch_bow_loss', None)
        if epoch_bow_loss is not None:
            self.writer.update_loss(loss=epoch_bow_loss,
                                    step_i=epoch_i + 1,
                                    name='bow_loss')

        validation_loss = getattr(self, 'validation_loss', None)
        if validation_loss is not None:
            self.writer.update_loss(loss=validation_loss,
                                    step_i=epoch_i + 1,
                                    name='validation_loss')

        average_bleu = getattr(self, "average_bleu", None)
        if average_bleu is not None:
            self.writer.update_loss(loss=average_bleu,
                                    step_i=epoch_i + 1,
                                    name='average_bleu')

        average_sequences = getattr(self, "average_sequences", None)
        if average_sequences is not None:
            self.writer.update_loss(loss=average_sequences,
                                    step_i=epoch_i + 1,
                                    name='average_sequences')

        average_levenshteins = getattr(self, "average_levenshteins", None)
        if average_levenshteins is not None:
            self.writer.update_loss(loss=average_levenshteins,
                                    step_i=epoch_i + 1,
                                    name='average_levenshteins')

    @time_desc_decorator('Training Start!')
    def train(self):
        """Teacher-forced training loop.

        For each epoch: shift every conversation by one sentence to form
        (input, target) pairs, run the model, accumulate the masked
        cross-entropy normalized by total target words, clip gradients,
        and step the optimizer. Saves/validates/logs on the configured
        schedules.

        Returns
        -------
        list of float
            Per-epoch average loss history.
        """
        epoch_loss_history = []
        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = []
            self.model.train()
            n_total_words = 0
            for batch_i, (conversations, conversation_length,
                          sentence_length) in enumerate(
                              tqdm(self.train_data_loader, ncols=80)):
                # conversations: (batch_size) list of conversations
                #   conversation: list of sentences
                #   sentence: list of tokens
                # conversation_length: list of int
                # sentence_length: (batch_size) list of conversation list of sentence_lengths

                # Inputs are all sentences but the last; targets are all but
                # the first (next-sentence prediction).
                input_conversations = [conv[:-1] for conv in conversations]
                target_conversations = [conv[1:] for conv in conversations]

                # flatten input and target conversations
                input_sentences = [
                    sent for conv in input_conversations for sent in conv
                ]
                target_sentences = [
                    sent for conv in target_conversations for sent in conv
                ]
                input_sentence_length = [
                    l for len_list in sentence_length for l in len_list[:-1]
                ]
                target_sentence_length = [
                    l for len_list in sentence_length for l in len_list[1:]
                ]
                input_conversation_length = [
                    l - 1 for l in conversation_length
                ]

                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

                # reset gradient
                self.optimizer.zero_grad()

                sentence_logits = self.model(input_sentences,
                                             input_sentence_length,
                                             input_conversation_length,
                                             target_sentences,
                                             decode=False)

                # batch_loss is summed over words; n_words normalizes it.
                batch_loss, n_words = masked_cross_entropy(
                    sentence_logits, target_sentences, target_sentence_length)

                assert not isnan(batch_loss.item())
                batch_loss_history.append(batch_loss.item())
                n_total_words += n_words.item()

                if batch_i % self.config.print_every == 0:
                    tqdm.write(
                        f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item()/ n_words.item():.3f}'
                    )

                # Back-propagation
                batch_loss.backward()

                # Gradient cliping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)

                # Run optimizer
                self.optimizer.step()

            # Per-word average over the whole epoch.
            epoch_loss = np.sum(batch_loss_history) / n_total_words
            epoch_loss_history.append(epoch_loss)
            self.epoch_loss = epoch_loss

            print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}'
            print(print_str)

            if epoch_i % self.config.save_every_epoch == 0:
                self.save_model(epoch_i + 1)

            print('\n<Validation>...')
            self.validation_loss = self.evaluate()

            if epoch_i % self.config.plot_every_epoch == 0:
                self.write_summary(epoch_i)

        self.save_model(self.config.n_epoch)

        return epoch_loss_history

    def generate_sentence(self, input_sentences, input_sentence_length,
                          input_conversation_length, target_sentences):
        """Decode responses for one batch and append input/ground-truth/
        generated triples to '<save_path>/samples.txt' (also printed)."""
        self.model.eval()

        # [batch_size, max_seq_len, vocab_size]
        generated_sentences = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences,
                                         decode=True)

        # write output to file
        with open(os.path.join(self.config.save_path, 'samples.txt'),
                  'a') as f:
            f.write(f'<Epoch {self.epoch_i}>\n\n')

            tqdm.write('\n<Samples>')
            for input_sent, target_sent, output_sent in zip(
                    input_sentences, target_sentences, generated_sentences):
                input_sent = self.vocab.decode(input_sent)
                target_sent = self.vocab.decode(target_sent)
                output_sent = '\n'.join(
                    [self.vocab.decode(sent) for sent in output_sent])
                s = '\n'.join([
                    'Input sentence: ' + input_sent,
                    'Ground truth: ' + target_sent,
                    'Generated response: ' + output_sent + '\n'
                ])
                f.write(s + '\n')
                print(s)
            print('')

    def evaluate(self):
        """Compute the per-word validation loss over eval_data_loader.

        Also dumps generated samples for the first batch.

        Returns
        -------
        float
            Average per-word masked cross-entropy.
        """
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length,
                      sentence_length) in enumerate(
                          tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [
                sent for conv in input_conversations for sent in conv
            ]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            input_sentence_length = [
                l for len_list in sentence_length for l in len_list[:-1]
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            input_conversation_length = [l - 1 for l in conversation_length]

            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

            if batch_i == 0:
                self.generate_sentence(input_sentences, input_sentence_length,
                                       input_conversation_length,
                                       target_sentences)

            # NOTE(review): this forward pass runs *outside* the
            # torch.no_grad() block above, so autograd still builds a graph
            # during validation (extra memory) — consider widening the scope.
            sentence_logits = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences)

            batch_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words

        print_str = f'Validation loss: {epoch_loss:.3f}\n'
        print(print_str)

        return epoch_loss

    def test(self):
        """Compute per-word loss on eval_data_loader and report word
        perplexity.

        Returns
        -------
        float
            Word perplexity, exp(per-word loss).
        """
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length,
                      sentence_length) in enumerate(
                          tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [
                sent for conv in input_conversations for sent in conv
            ]
            target_sentences = [
                sent for conv in target_conversations for sent in conv
            ]
            input_sentence_length = [
                l for len_list in sentence_length for l in len_list[:-1]
            ]
            target_sentence_length = [
                l for len_list in sentence_length for l in len_list[1:]
            ]
            input_conversation_length = [l - 1 for l in conversation_length]

            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(
                    torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(
                    torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

            # NOTE(review): like evaluate(), the forward pass is outside the
            # no_grad block — gradients are tracked needlessly here.
            sentence_logits = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences)

            batch_loss, n_words = masked_cross_entropy(sentence_logits,
                                                       target_sentences,
                                                       target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words

        print(f'Number of words: {n_total_words}')
        print(f'Bits per word: {epoch_loss:.3f}')
        # NOTE(review): exp() of the per-word loss yields perplexity only if
        # the loss is in nats; if so, the 'Bits per word' label above is a
        # misnomer (bits would need log base 2) — confirm masked_cross_entropy.
        word_perplexity = np.exp(epoch_loss)

        print_str = f'Word perplexity : {word_perplexity:.3f}\n'
        print(print_str)

        return word_perplexity

    def embedding_metric(self):
        """Score generated responses against ground truth with word2vec
        embedding metrics (average / extrema / greedy).

        Lazily loads (and caches on self) the word2vec model, generates
        continuations from the first n_context sentences of each long-enough
        conversation, and averages the three metrics over all sentences.

        Returns
        -------
        tuple of float
            (average, extrema, greedy) metric means.
        """
        word2vec = getattr(self, 'word2vec', None)
        if word2vec is None:
            print('Loading word2vec model')
            # NOTE(review): word2vec_path is assumed to be defined at module
            # level — confirm.
            word2vec = gensim.models.KeyedVectors.load_word2vec_format(
                word2vec_path, binary=True)
            self.word2vec = word2vec
        # NOTE(review): KeyedVectors.vocab exists in gensim < 4.0 only
        # (renamed to key_to_index in 4.x) — confirm pinned gensim version.
        keys = word2vec.vocab
        self.model.eval()
        n_context = self.config.n_context
        n_sample_step = self.config.n_sample_step
        metric_average_history = []
        metric_extrema_history = []
        metric_greedy_history = []
        context_history = []
        sample_history = []
        n_sent = 0
        n_conv = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            # Keep only conversations long enough to supply both the context
            # window and the sampled continuation.
            conv_indices = [
                i for i in range(len(conversations))
                if len(conversations[i]) >= n_context + n_sample_step
            ]
            context = [
                c for i in conv_indices
                for c in [conversations[i][:n_context]]
            ]
            ground_truth = [
                c for i in conv_indices for c in
                [conversations[i][n_context:n_context + n_sample_step]]
            ]
            sentence_length = [
                c for i in conv_indices
                for c in [sentence_length[i][:n_context]]
            ]

            with torch.no_grad():
                context = to_var(torch.LongTensor(context))
                sentence_length = to_var(torch.LongTensor(sentence_length))

            samples = self.model.generate(context, sentence_length, n_context)

            context = context.data.cpu().numpy().tolist()
            samples = samples.data.cpu().numpy().tolist()
            context_history.append(context)
            sample_history.append(samples)

            samples = [[self.vocab.decode(sent) for sent in c]
                       for c in samples]
            ground_truth = [[self.vocab.decode(sent) for sent in c]
                            for c in ground_truth]

            samples = [sent for c in samples for sent in c]
            ground_truth = [sent for c in ground_truth for sent in c]

            # Map each sentence to its in-vocabulary word vectors.
            samples = [[word2vec[s] for s in sent.split() if s in keys]
                       for sent in samples]
            ground_truth = [[word2vec[s] for s in sent.split() if s in keys]
                            for sent in ground_truth]

            # Drop pairs where either side has no in-vocabulary words.
            indices = [
                i
                for i, s, g in zip(range(len(samples)), samples, ground_truth)
                if s != [] and g != []
            ]
            samples = [samples[i] for i in indices]
            ground_truth = [ground_truth[i] for i in indices]
            n = len(samples)
            n_sent += n

            metric_average = embedding_metric(samples, ground_truth, word2vec,
                                              'average')
            metric_extrema = embedding_metric(samples, ground_truth, word2vec,
                                              'extrema')
            metric_greedy = embedding_metric(samples, ground_truth, word2vec,
                                             'greedy')
            metric_average_history.append(metric_average)
            metric_extrema_history.append(metric_extrema)
            metric_greedy_history.append(metric_greedy)

        epoch_average = np.mean(np.concatenate(metric_average_history), axis=0)
        epoch_extrema = np.mean(np.concatenate(metric_extrema_history), axis=0)
        epoch_greedy = np.mean(np.concatenate(metric_greedy_history), axis=0)

        print('n_sentences:', n_sent)
        # NOTE(review): the :.3f formats below require the means to be
        # scalars; confirm embedding_metric returns 1-D per-sentence scores.
        print_str = f'Metrics - Average: {epoch_average:.3f}, Extrema: {epoch_extrema:.3f}, Greedy: {epoch_greedy:.3f}'
        print(print_str)
        print('\n')

        return epoch_average, epoch_extrema, epoch_greedy
コード例 #28
0
class Trainer:
    """
    Encoder-decoder pipeline. Teacher forcing is used during training and validation.

    Parameters
    ----------
    caption_model : str
        Type of the caption model

    epochs : int
        We should train the model for __ epochs

    device : torch.device
        Use GPU or not

    word_map : Dict[str, int]
        Word2id map

    rev_word_map : Dict[int, str]
        Id2word map

    start_epoch : int
        We should start training the model from __th epoch

    epochs_since_improvement : int
        Number of epochs since last improvement in BLEU-4 score

    best_bleu4 : float
        Best BLEU-4 score until now

    train_loader : DataLoader
        DataLoader for training data

    val_loader : DataLoader
        DataLoader for validation data

    encoder : nn.Module
        Encoder (based on CNN)

    decoder : nn.Module
        Decoder (based on LSTM)

    encoder_optimizer : optim.Optimizer
        Optimizer for encoder (Adam) (if fine-tune)

    decoder_optimizer : optim.Optimizer
        Optimizer for decoder (Adam)

    loss_function : nn.Module
        Loss function (cross entropy)

    grad_clip : float
        Gradient threshold in clip gradients

    tau : float
        Penalty term τ for doubly stochastic attention in paper: show, attend and tell

    fine_tune_encoder : bool
        Fine-tune encoder or not

    tensorboard : bool, optional, default=False
        Enable tensorboard or not?

    log_dir : str, optional
        Path to the folder to save logs for tensorboard
    """
    def __init__(
        self,
        caption_model: str,
        epochs: int,
        device: torch.device,
        word_map: Dict[str, int],
        rev_word_map: Dict[int, str],
        start_epoch: int,
        epochs_since_improvement: int,
        best_bleu4: float,
        train_loader: DataLoader,
        val_loader: DataLoader,
        encoder: nn.Module,
        decoder: nn.Module,
        encoder_optimizer: optim.Optimizer,
        decoder_optimizer: optim.Optimizer,
        loss_function: nn.Module,
        grad_clip: float,
        tau: float,
        fine_tune_encoder: bool,
        tensorboard: bool = False,
        log_dir: Optional[str] = None
    ) -> None:
        """Wire up the training state; parameters are documented on the
        class docstring. No training work happens here."""
        self.device = device  # GPU / CPU

        self.caption_model = caption_model
        self.epochs = epochs
        self.word_map = word_map
        self.rev_word_map = rev_word_map

        # Resume bookkeeping (supports warm-starting from a checkpoint).
        self.start_epoch = start_epoch
        self.epochs_since_improvement = epochs_since_improvement
        self.best_bleu4 = best_bleu4

        self.train_loader =  train_loader
        self.val_loader = val_loader
        self.encoder = encoder
        self.decoder = decoder
        # encoder_optimizer may be None when the encoder is not fine-tuned.
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
        self.loss_function = loss_function

        self.tau = tau
        self.grad_clip = grad_clip
        self.fine_tune_encoder = fine_tune_encoder

        self.print_freq = 100  # print training/validation stats every __ batches
        # setup visualization writer instance
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)

    def train(self, epoch: int) -> None:
        """
        Train an epoch: teacher-forced forward pass, cross-entropy loss
        (plus doubly stochastic attention regularization for 'att2all'),
        backward, optional gradient clipping, optimizer step, and
        metric/Tensorboard logging.

        Parameters
        ----------
        epoch : int
            Current number of epoch
        """
        self.decoder.train()  # train mode (dropout and batchnorm is used)
        self.encoder.train()

        batch_time = AverageMeter()  # forward prop. + back prop. time
        data_time = AverageMeter()  # data loading time
        losses = AverageMeter(tag='loss', writer=self.writer)  # loss (per word decoded)
        top5accs = AverageMeter(tag='top5acc', writer=self.writer)  # top5 accuracy

        start = time.time()

        # batches
        for i, (imgs, caps, caplens) in enumerate(self.train_loader):
            data_time.update(time.time() - start)

            # Move to GPU, if available
            imgs = imgs.to(self.device)
            caps = caps.to(self.device)
            caplens = caplens.to(self.device)

            # forward encoder
            imgs = self.encoder(imgs)

            # forward decoder ('att2all' additionally returns attention weights)
            if self.caption_model == 'att2all':
                scores, caps_sorted, decode_lengths, alphas, sort_ind = self.decoder(imgs, caps, caplens)
            else:
                scores, caps_sorted, decode_lengths, sort_ind = self.decoder(imgs, caps, caplens)

            # since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]

            # calc loss
            loss = self.loss_function(scores, targets)

            # doubly stochastic attention regularization (in paper: show, attend and tell)
            if self.caption_model == 'att2all':
                loss += self.tau * ((1. - alphas.sum(dim = 1)) ** 2).mean()

            # clear gradient of last batch
            self.decoder_optimizer.zero_grad()
            if self.encoder_optimizer is not None:
                self.encoder_optimizer.zero_grad()

            # backward
            loss.backward()

            # clip gradients
            if self.grad_clip is not None:
                clip_gradient(self.decoder_optimizer, self.grad_clip)
                if self.encoder_optimizer is not None:
                    clip_gradient(self.encoder_optimizer, self.grad_clip)

            # update weights
            self.decoder_optimizer.step()
            if self.encoder_optimizer is not None:
                self.encoder_optimizer.step()

            # set step for tensorboard (global batch index across epochs)
            step = (epoch - 1) * self.len_epoch + i
            self.writer.set_step(step=step, mode='train')

            # keep track of metrics, weighted by the number of decoded words
            top5 = accuracy(scores, targets, 5)
            losses.update(loss.item(), sum(decode_lengths))
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            # print status
            if i % self.print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch, i, len(self.train_loader),
                        batch_time = batch_time,
                        data_time = data_time,
                        loss = losses,
                        top5 = top5accs
                    )
                )

    def validate(self) -> float:
        """
        Run one full validation pass over ``self.val_loader``.

        Computes the average loss and top-5 accuracy, collects ground-truth
        and predicted captions for the whole validation set, and scores them
        with BLEU-4 and CIDEr.

        Returns
        -------
        bleu4 : float
            BLEU-4 score over the entire validation set.
        """
        self.decoder.eval()  # eval mode (no dropout or batchnorm)
        if self.encoder is not None:
            self.encoder.eval()

        # running averages for loss / top-5 accuracy / per-batch wall time
        batch_time = AverageMeter()
        losses = AverageMeter()
        top5accs = AverageMeter()

        start = time.time()

        ground_truth = list()  # ground_truth (true captions) for calculating BLEU-4 score
        prediction = list()  # prediction (predicted captions)

        # explicitly disable gradient calculation to avoid CUDA memory error
        # solves the issue #57
        with torch.no_grad():
            # Batches
            for i, (imgs, caps, caplens, allcaps) in enumerate(self.val_loader):

                # move to device, if available
                imgs = imgs.to(self.device)
                caps = caps.to(self.device)
                caplens = caplens.to(self.device)

                # forward encoder
                if self.encoder is not None:
                    imgs = self.encoder(imgs)

                # forward decoder
                # 'att2all' additionally returns attention weights (alphas),
                # needed below for the doubly stochastic regularizer
                if self.caption_model == 'att2all':
                    scores, caps_sorted, decode_lengths, alphas, sort_ind = self.decoder(imgs, caps, caplens)
                else:
                    scores, caps_sorted, decode_lengths, sort_ind = self.decoder(imgs, caps, caplens)

                # since we decoded starting with <start>, the targets are all words after <start>, up to <end>
                targets = caps_sorted[:, 1:]

                # remove timesteps that we didn't decode at, or are pads
                # pack_padded_sequence is an easy trick to do this
                # NOTE: keep an unpacked copy first — it is needed later to
                # recover per-timestep argmax predictions
                scores_copy = scores.clone()
                scores = pack_padded_sequence(scores, decode_lengths, batch_first = True)[0]
                targets = pack_padded_sequence(targets, decode_lengths, batch_first = True)[0]

                # calc loss
                loss = self.loss_function(scores, targets)

                # doubly stochastic attention regularization (in paper: show, attend and tell)
                # penalizes attention weights whose per-pixel sums deviate from 1
                if self.caption_model == 'att2all':
                    loss += self.tau * ((1. - alphas.sum(dim = 1)) ** 2).mean()

                # keep track of metrics (weighted by the total number of decoded tokens)
                losses.update(loss.item(), sum(decode_lengths))
                top5 = accuracy(scores, targets, 5)
                top5accs.update(top5, sum(decode_lengths))
                batch_time.update(time.time() - start)

                start = time.time()

                if i % self.print_freq == 0:
                    print('Validation: [{0}/{1}]\t'
                        'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(self.val_loader),
                                                                                    batch_time = batch_time,
                                                                                    loss = losses,
                                                                                    top5 = top5accs)
                    )

                # store ground truth captions and predicted captions of each image
                # for n images, each of them has one prediction and multiple ground truths (a, b, c...):
                # prediction = [ [hyp1], [hyp2], ..., [hypn] ]
                # ground_truth = [ [ [ref1a], [ref1b], [ref1c] ], ..., [ [refna], [refnb] ] ]

                # ground truth
                # NOTE(review): assumes allcaps and sort_ind live on compatible
                # devices for this indexing — confirm against the data loader
                allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
                for j in range(allcaps.shape[0]):
                    img_caps = allcaps[j].tolist()
                    img_captions = list(
                        map(
                            lambda c: [w for w in c if w not in {self.word_map['<start>'], self.word_map['<pad>']}],
                            img_caps
                        )
                    )  # remove <start> and pads
                    ground_truth.append(img_captions)

                # prediction: greedy per-timestep argmax over the vocabulary,
                # truncated to each caption's true decode length
                _, preds = torch.max(scores_copy, dim = 2)
                preds = preds.tolist()
                temp_preds = list()
                for j, p in enumerate(preds):
                    temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
                preds = temp_preds
                prediction.extend(preds)

                assert len(ground_truth) == len(prediction)

            # calc BLEU-4 and CIDEr score
            # NOTE(review): attribute is spelled `belu` on the Metrics class —
            # verify against its definition before renaming
            metrics = Metrics(ground_truth, prediction, self.rev_word_map)
            bleu4 = metrics.belu[3]  # BLEU-4
            cider = metrics.cider  # CIDEr

            print(
                '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}, CIDEr - {cider}\n'.format(
                    loss = losses,
                    top5 = top5accs,
                    bleu = bleu4,
                    cider = cider
                )
            )

        return bleu4

    def run_train(self) -> None:
        """
        Drive the full training schedule.

        For each epoch: optionally decay the learning rate, train, validate,
        track BLEU-4 improvement, and save a checkpoint. Training stops early
        after 20 consecutive epochs without improvement; the learning rate is
        decayed by 0.8 after every 8 consecutive stagnant epochs.
        """
        for current_epoch in range(self.start_epoch, self.epochs):

            # give up entirely once 20 epochs pass with no BLEU-4 gain
            if self.epochs_since_improvement == 20:
                break

            # every 8th stagnant epoch, shrink the learning rate(s) by 20%
            stagnant = self.epochs_since_improvement
            if stagnant > 0 and stagnant % 8 == 0:
                adjust_learning_rate(self.decoder_optimizer, 0.8)
                if self.fine_tune_encoder:
                    adjust_learning_rate(self.encoder_optimizer, 0.8)

            # one pass of training followed by one pass of validation
            self.train(epoch = current_epoch)
            epoch_bleu4 = self.validate()

            # improvement bookkeeping: reset the stagnation counter on a new
            # best score, otherwise bump it and report
            improved = epoch_bleu4 > self.best_bleu4
            if improved:
                self.best_bleu4 = epoch_bleu4
                self.epochs_since_improvement = 0
            else:
                self.epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" % (self.epochs_since_improvement,))

            # persist model, optimizers and progress (flags the best model)
            save_checkpoint(
                epoch = current_epoch,
                epochs_since_improvement = self.epochs_since_improvement,
                encoder = self.encoder,
                decoder = self.decoder,
                encoder_optimizer = self.encoder_optimizer,
                decoder_optimizer = self.decoder_optimizer,
                caption_model = self.caption_model,
                bleu4 = epoch_bleu4,
                is_best = improved
            )