Example #1
    def run(self, block_name: str = None) -> bool:
        """Run the evaluation.

        Returns
        -------
        bool
            Whether the component should continue running.

        """
        self.model.to(self.device)
        self.model.eval()
        with torch.no_grad():
            preds, targets = [], []

            for batch in self._eval_iterator:
                pred, target = self.model(*[t.to(self.device) for t in batch])
                preds.append(pred.cpu())
                targets.append(target.cpu())

            preds = torch.cat(preds, dim=0)  # type: ignore
            targets = torch.cat(targets, dim=0)  # type: ignore
            self.eval_metric = self.metric_fn(preds, targets).item()

            tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

            log(
                f'{tb_prefix}Eval {self.metric_fn}',  # type: ignore
                self.eval_metric,
                global_step=0)  # type: ignore

        continue_ = False  # Single step so don't continue
        return continue_
Example #2
File: eval.py  Project: yukw777/flambe
    def run(self, block_name: str = None) -> bool:
        """Run the evaluation.

        Returns
        -------
        bool
            Whether the component should continue running.

        """
        self.model.to(self.device)
        self.model.eval()
        with torch.no_grad():
            metric_state: Dict = {}

            for batch in self._eval_iterator:
                pred, target = self.model(*[t.to(self.device) for t in batch])
                self.metric_fn.aggregate(metric_state, pred, target)

            self.eval_metric = self.metric_fn.finalize(metric_state)

            tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

            log(
                f'{tb_prefix}Eval/{self.metric_fn}',  # type: ignore
                self.eval_metric,
                global_step=0)  # type: ignore

        return False
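
Examples #1 and #2 compute the evaluation metric in two different ways: #1 collects every prediction and target on the CPU and applies metric_fn once at the end, while #2 streams each batch through metric_fn.aggregate() and reduces with finalize(), so memory stays flat in the number of batches. Below is a minimal, hypothetical sketch of a metric that fits the aggregate()/finalize() calls used in Example #2; flambe's real Metric base class is not reproduced here, and StreamingAccuracy with its state keys is illustrative only.

from typing import Dict

import torch


class StreamingAccuracy:
    """Hypothetical metric matching the aggregate()/finalize() usage above."""

    def aggregate(self, state: Dict, pred: torch.Tensor, target: torch.Tensor) -> Dict:
        # Accumulate running counts instead of keeping every prediction in memory.
        correct = (pred.argmax(dim=-1) == target).sum().item()
        state['correct'] = state.get('correct', 0) + correct
        state['total'] = state.get('total', 0) + target.numel()
        return state

    def finalize(self, state: Dict) -> float:
        # Reduce the accumulated counts to a single scalar.
        return state['correct'] / max(state['total'], 1)


metric_fn = StreamingAccuracy()
metric_state: Dict = {}
for _ in range(3):                              # stand-in for _eval_iterator
    pred = torch.randn(8, 5)                    # (batch, num_classes) logits
    target = torch.randint(0, 5, (8,))
    metric_fn.aggregate(metric_state, pred, target)
print(metric_fn.finalize(metric_state))
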
Example #3
    def _eval_step(self) -> None:

        if self.teacher_translator is not None and self._step == 0:
            self.teacher_translator.initialize(self.teacher)

            iterator = self.train_sampler.sample(self.dataset.train, 1)

            batch_size = self.train_sampler.batch_size
            drop_last = self.train_sampler.drop_last
            shuffle = self.train_sampler.shuffle

            srcs = []
            new_contexts = []
            new_words = []

            for batch in iterator:

                batch = (t.to(self.device) for t in batch)
                src, tgt_context, tgt_words = batch

                with torch.no_grad():
                    new_tgt, new_src = self.teacher_translator(src, src)
                new_tgt_context = new_tgt[:, :-1]
                new_tgt_words = new_tgt[:, 1:]

                srcs.append(new_src[:, :-1].cpu())
                new_contexts.append(new_tgt_context.cpu())
                new_words.append(new_tgt_words.cpu())

            srcs = torch.cat(srcs, dim=0)
            new_contexts = torch.cat(new_contexts, dim=0)
            new_words = torch.cat(new_words, dim=0)

            original_data = (srcs, new_contexts, new_words)
            train_sets = [original_data]

            self.train_sampler = TensorSampler(train_sets,
                                               probs=[1.0],
                                               batch_size=batch_size,
                                               drop_last=drop_last,
                                               shuffle=shuffle)
            self._create_train_iterator()

        super()._eval_step()

        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
        # Log beta
        if not self.use_iter:
            self.beta = self.get_beta(self._step)
            log(f'{tb_prefix}Training/Beta', self.beta, self._step)
Example #4
    def _log_metrics(log_prefix: str, metrics_with_states: List[Tuple],
                     global_step: int) -> None:
        """Logs all provided metrics

        Iterates through the provided list of metrics with states,
        finalizes the metric, and logs it.

        Parameters
        ----------
        log_prefix: str
            A string, such as a tensorboard prefix
        metrics_with_states: List[Tuple[Metric, Dict]]
            a list of metric-state tuples
        global_step: int
            the global step for logging
        """
        for metric, state in metrics_with_states:
            log(f'{log_prefix}/{metric}', metric.finalize(state), global_step)
Example #5
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()

        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

        with torch.enable_grad():
            for i in range(self.iter_per_step):
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):
                    # Get next batch
                    try:
                        batch = next(self._train_iterator)
                    except StopIteration:
                        self._create_train_iterator()
                        batch = next(self._train_iterator)
                    batch = self._batch_to_device(batch)

                    # Compute loss
                    loss = self._compute_loss(batch) / self.batches_per_iter
                    accumulated_loss += loss.item()
                    loss.backward()

                # Log loss
                global_step = (self.iter_per_step * self._step) + i

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(),
                                    self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(),
                                     self.max_grad_abs_val)

                log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
                log(f'{tb_prefix}Training/Gradient_Norm',
                    self.model.gradient_norm, global_step)
                log(f'{tb_prefix}Training/Parameter_Norm',
                    self.model.parameter_norm, global_step)

                # Optimize
                self.optimizer.step()

                # Update iter scheduler
                if self.iter_scheduler is not None:
                    learning_rate = self.iter_scheduler.get_lr()[
                        0]  # type: ignore
                    log(f'{tb_prefix}Training/LR', learning_rate, global_step)
                    self.iter_scheduler.step()  # type: ignore

                # Zero the gradients when exiting a train step
                self.optimizer.zero_grad()
Example #6
File: train.py  Project: jwohlwend/flambe
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()

        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

        with torch.enable_grad():
            for i in range(self.iter_per_step):
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):
                    # Get next batch
                    batch = next(self._train_iterator)
                    batch = self._batch_to_device(batch)

                    # Compute loss
                    loss = self._compute_loss(batch) / self.batches_per_iter
                    accumulated_loss += loss.item()
                    loss.backward()

                # Log loss
                global_step = (self.iter_per_step * self._step) + i

                log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
                log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step)
                log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step)

                # Optimize
                self.optimizer.step()

            # Zero the gradients when exiting a train step
            self.optimizer.zero_grad()
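
Both _train_step variants above divide each batch loss by batches_per_iter before calling backward(), so the gradients summed over the inner loop equal the gradient of the mean loss and the single optimizer.step() behaves like one update on a larger effective batch. A standalone sketch of that accumulation pattern on a toy model (the model, data, and learning rate below are illustrative only):

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches_per_iter = 4
toy_batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(batches_per_iter)]

optimizer.zero_grad()
accumulated_loss = 0.0
for x, y in toy_batches:
    # Scale each batch loss so the summed gradients match the mean-loss gradient.
    loss = nn.functional.mse_loss(model(x), y) / batches_per_iter
    accumulated_loss += loss.item()
    loss.backward()          # gradients accumulate across the inner loop
optimizer.step()             # one parameter update per outer iteration
print(accumulated_loss)      # average loss over the accumulated batches
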
Example #7
    def _eval_step(self) -> None:
        """Run an evaluation step over the validation data."""
        self.model.eval()
        metric_fn_state: Dict[Metric, Dict] = {}
        metrics_with_states: List[Tuple] = \
            [(metric, {}) for metric in self.validation_metrics]

        # Initialize a 1-epoch iteration through the validation set
        val_iterator = self.val_sampler.sample(self.dataset.val)

        with torch.no_grad():
            loss = []
            for batch in val_iterator:
                _, _, batch_loss = self._compute_batch(
                    batch,
                    [(self.metric_fn, metric_fn_state), *metrics_with_states])
                loss.append(batch_loss.item())
            val_loss = np.NaN if loss == [] else sum(loss) / len(loss)
            val_metric = self.metric_fn.finalize(metric_fn_state)

        # Update best model
        sign = (-1)**(self.lower_is_better)
        if self._best_metric is None or (sign * val_metric >
                                         sign * self._best_metric):
            self._best_metric = val_metric
            best_model_state = self.model.state_dict()
            for k, t in best_model_state.items():
                best_model_state[k] = t.cpu().detach()
            self._best_model = best_model_state

        # Update scheduler
        if self.scheduler is not None:
            if isinstance(self.scheduler, ReduceLROnPlateau):
                self.scheduler.step(val_loss)
            else:
                # torch's _LRScheduler.step DOES have a default value
                # so passing in no args is fine; it will automatically
                # compute the current epoch
                self.scheduler.step()  # type: ignore

        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

        # Log metrics
        log(f'{tb_prefix}Validation/Loss', val_loss, self._step)
        log(f'{tb_prefix}Validation/{self.metric_fn}', val_metric, self._step)
        log(f'{tb_prefix}Best/{self.metric_fn}', self._best_metric,
            self._step)  # type: ignore
        for (metric, state) in metrics_with_states:
            log(f'{tb_prefix}Validation/{metric}', metric.finalize(state),
                self._step)  # type: ignore
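
The best-model check in _eval_step uses a sign trick: (-1) ** lower_is_better is -1 when lower_is_better is True, so the single comparison sign * val_metric > sign * best_metric handles both "higher is better" and "lower is better" metrics. A standalone restatement of that condition:

def is_improvement(val_metric: float, best_metric: float, lower_is_better: bool) -> bool:
    # bool acts as 0/1, so the sign flips the direction of the comparison.
    sign = (-1) ** lower_is_better
    return sign * val_metric > sign * best_metric

assert is_improvement(0.9, 0.8, lower_is_better=False)      # accuracy went up
assert is_improvement(0.4, 0.5, lower_is_better=True)       # loss went down
assert not is_improvement(0.6, 0.5, lower_is_better=True)   # loss went up
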
Example #8
    def _eval_step(self) -> None:
        """Run an evaluation step over the validation data."""
        self.model.eval()

        # Initialize a 1-epoch iteration through the validation set
        val_iterator = self.val_sampler.sample(self.dataset.val)

        with torch.no_grad():

            preds, targets = self._aggregate_preds(val_iterator)
            val_loss = self.loss_fn(preds, targets).item()
            val_metric = self.metric_fn(preds, targets).item()

        # Update best model
        sign = (-1)**(self.lower_is_better)
        if self._best_metric is None or (sign * val_metric >
                                         sign * self._best_metric):
            self._best_metric = val_metric
            best_model_state = self.model.state_dict()
            for k, t in best_model_state.items():
                best_model_state[k] = t.cpu().detach()
            self._best_model = best_model_state

        # Update scheduler
        if self.scheduler is not None:
            if isinstance(self.scheduler, ReduceLROnPlateau):
                self.scheduler.step(val_loss)
            else:
                # torch's _LRScheduler.step DOES have a default value
                # so passing in no args is fine; it will automatically
                # compute the current epoch
                self.scheduler.step()  # type: ignore

        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

        # Log metrics
        log(f'{tb_prefix}Validation/Loss', val_loss, self._step)
        log(f'{tb_prefix}Validation/{self.metric_fn}', val_metric, self._step)
        log(f'{tb_prefix}Best/{self.metric_fn}', self._best_metric,
            self._step)  # type: ignore
        for metric_name, metric in self.extra_validation_metrics.items():
            log(f'{tb_prefix}Validation/{metric_name}',
                metric(preds, targets).item(), self._step)  # type: ignore
Example #9
    def _compute_loss(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        if self.use_iter:
            tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
            log(f'{tb_prefix}Training/Beta', self.beta, self.global_iter)

        if self.top_k != 'None':
            loss = self._compute_loss_top_k(batch)

            if self.use_iter:
                self.global_iter += 1
                self.beta = self.get_beta(self.global_iter)

            return loss

        if self.encode_during_sampling:
            loss = self._encode_and_sample_rnn(batch)

            if self.use_iter:
                self.global_iter += 1
                self.beta = self.get_beta(self.global_iter)

            return loss

        src, tgt_context, tgt_words = batch

        # Sample from translator
        self.model.eval()
        dist = Categorical(torch.tensor([self.beta, 1 - self.beta]))
        samp_mask = (dist.sample((src.size(0), )) == 1)

        if torch.sum(samp_mask).item() > 0:
            samp_src = src[samp_mask]
            with torch.no_grad():
                A = time.time()
                samp_tgt_context = self.translator(samp_src)
                print('FIRST BLOCK ' + str(time.time() - A))

                A = time.time()
                samp_tgt_logits = self.teacher(samp_src, samp_tgt_context)
                _, samp_tgt_words = samp_tgt_logits.max(dim=-1)
                print('SECOND BLOCK ' + str(time.time() - A))

            eos_mask = (samp_tgt_context == self.translator.tgt_eos_idx)
            eos_mask |= (samp_tgt_context == self.translator.tgt_pad_idx)
            samp_tgt_words[eos_mask] = self.translator.tgt_pad_idx

            # Merge original and sampled data
            orig_src = src[~samp_mask]
            orig_tgt_context = tgt_context[~samp_mask]
            orig_tgt_words = tgt_words[~samp_mask]

            # Add padding if necessary
            tgt_pad_idx = torch.tensor(self.translator.tgt_pad_idx)
            diff = samp_tgt_context.size(1) - orig_tgt_context.size(1)
            if diff > 0:
                extra = tgt_pad_idx.repeat(
                    (orig_tgt_context.size(0), diff)).to(self.device)
                orig_tgt_context = torch.cat([orig_tgt_context, extra], dim=1)
            diff = samp_tgt_words.size(1) - orig_tgt_words.size(1)
            if diff > 0:
                extra = tgt_pad_idx.repeat(
                    (orig_tgt_words.size(0), diff)).to(self.device)
                orig_tgt_words = torch.cat([orig_tgt_words, extra], dim=1)

            new_src = torch.cat([orig_src, samp_src], dim=0)
            new_tgt_context = torch.cat([orig_tgt_context, samp_tgt_context],
                                        dim=0)
            new_tgt_words = torch.cat([orig_tgt_words, samp_tgt_words], dim=0)
        else:
            new_src = src
            new_tgt_context = tgt_context
            new_tgt_words = tgt_words

        # Train model on merged data
        A = time.time()
        self.model.train()
        pred, target = self.model(new_src, new_tgt_context, new_tgt_words)
        loss = self.loss_fn(pred, target)
        print('THIRD BLOCK ' + str(time.time() - A))

        if self.use_iter:
            self.global_iter += 1
            self.beta = self.get_beta(self.global_iter)

        return loss
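
The scheduled-sampling mask in _compute_loss draws, independently for each source sequence in the batch, index 0 with probability beta (keep the reference target) or index 1 with probability 1 - beta (replace it with the translator/teacher output). A standalone sketch of the mask construction; beta and the batch size are illustrative:

import torch
from torch.distributions import Categorical

beta = 0.7
batch_size = 10

dist = Categorical(torch.tensor([beta, 1 - beta]))
samp_mask = dist.sample((batch_size,)) == 1   # True with probability 1 - beta
print(samp_mask)
print(samp_mask.float().mean())               # roughly 1 - beta for large batches
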
Example #10
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()
        metrics_with_states: List[Tuple] = [
            (metric, {}) for metric in self.training_metrics
        ]
        self._last_train_log_step = 0

        log_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
        log_prefix += 'Training'

        with torch.enable_grad():
            for i in range(self.iter_per_step):

                # print('MEMORY ALLOCATED %f' % float(torch.cuda.memory_allocated() / BYTES_IN_GB))
                # print('MEMORY CACHED %f' % float(torch.cuda.memory_cached() / BYTES_IN_GB))

                t = time.time()
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):

                    # Get next batch
                    try:
                        batch = next(self._train_iterator)
                    except StopIteration:
                        self._create_train_iterator()
                        batch = next(self._train_iterator)

                    if self.device_count == 1:
                        batch = self._batch_to_device(batch)

                    # Compute loss
                    if self.top_k == 'None':
                        _, _, loss = self._compute_batch(
                            batch, metrics_with_states)
                    else:
                        loss = self._compute_kl_loss(batch)
                    print('LOSS')
                    print(loss)
                    accumulated_loss += loss.item() / self.batches_per_iter

                    loss.backward()
                    # try:
                    #     loss.backward()
                    # except RuntimeError:
                    #     torch.cuda.empty_cache()
                    #     print('EMPTIED CACHE FOR LOSS')
                    #     continue

                # Log loss
                global_step = (self.iter_per_step * self._step) + i
                self.beta = self.get_beta(global_step)

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(),
                                    self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(),
                                     self.max_grad_abs_val)

                log(f'{log_prefix}/Loss', accumulated_loss, global_step)
                if self.device_count > 1:
                    log(f'{log_prefix}/Gradient_Norm',
                        self.model.module.gradient_norm, global_step)
                    log(f'{log_prefix}/Parameter_Norm',
                        self.model.module.parameter_norm, global_step)
                else:
                    log(f'{log_prefix}/Gradient_Norm',
                        self.model.gradient_norm, global_step)
                    log(f'{log_prefix}/Parameter_Norm',
                        self.model.parameter_norm, global_step)
                log(f'{log_prefix}/Beta', self.beta, global_step)

                # Optimize
                self.optimizer.step()

                # Update iter scheduler
                if self.iter_scheduler is not None:
                    lr = self.optimizer.param_groups[0]['lr']  # type: ignore
                    log(f'{log_prefix}/LR', lr, global_step)
                    self.iter_scheduler.step()  # type: ignore

                # Zero the gradients when exiting a train step
                self.optimizer.zero_grad()
                # logging train metrics
                if self.extra_training_metrics_log_interval > self._last_train_log_step:
                    self._log_metrics(log_prefix, metrics_with_states,
                                      global_step)
                    self._last_train_log_step = i

                print('TOTAL TIME: %f' % (time.time() - t))
            if self._last_train_log_step != i:
                # log again at end of step, if not logged at the end of
                # step before
                self._log_metrics(log_prefix, metrics_with_states, global_step)
Example #11
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()
        metrics_with_states: List[Tuple] = [
            (metric, {}) for metric in self.training_metrics
        ]
        self._last_train_log_step = 0

        log_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
        log_prefix += 'Training'

        with torch.enable_grad():
            for i in range(self.iter_per_step):
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):
                    # Get next batch
                    try:
                        batch = next(self._train_iterator)
                    except StopIteration:
                        self._create_train_iterator()
                        batch = next(self._train_iterator)
                    batch = self._batch_to_device(batch)

                    # Compute loss
                    _, _, loss = self._compute_batch(batch,
                                                     metrics_with_states)
                    accumulated_loss += loss.item() / self.batches_per_iter
                    loss.backward()

                # Log loss
                global_step = (self.iter_per_step * self._step) + i

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(),
                                    self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(),
                                     self.max_grad_abs_val)

                log(f'{log_prefix}/Loss', accumulated_loss, global_step)
                log(f'{log_prefix}/Gradient_Norm', self.model.gradient_norm,
                    global_step)
                log(f'{log_prefix}/Parameter_Norm', self.model.parameter_norm,
                    global_step)

                # Optimize
                self.optimizer.step()

                # Update iter scheduler
                if self.iter_scheduler is not None:
                    learning_rate = self.iter_scheduler.get_lr()[
                        0]  # type: ignore
                    log(f'{log_prefix}/LR', learning_rate, global_step)
                    self.iter_scheduler.step()  # type: ignore

                # Zero the gradients when exiting a train step
                self.optimizer.zero_grad()
                # logging train metrics
                if self.extra_training_metrics_log_interval > self._last_train_log_step:
                    self._log_metrics(log_prefix, metrics_with_states,
                                      global_step)
                    self._last_train_log_step = i
            if self._last_train_log_step != i:
                # log again at end of step, if not logged at the end of
                # step before
                self._log_metrics(log_prefix, metrics_with_states, global_step)
Example #12
File: sampler.py  Project: al5250/imitkd
    def sample(self,
               data: Sequence[Sequence[torch.Tensor]],
               n_epochs: int = 1) -> Iterator[Tuple[torch.Tensor, ...]]:
        if self.translator is None or self.teacher is None:
            raise ValueError(
                'Cannot sample because one of student or teacher is missing!')

        samp_per_epoch = math.ceil(self.length(data) / self.sample_factor)
        N = n_epochs * samp_per_epoch
        beta = 1.

        if self.scheduler_type == 'linear':
            get_beta = lambda t: 1. - t / N
        elif self.scheduler_type == 'exponential':
            get_beta = lambda t: 200**(-t / N)
        elif self.scheduler_type == 'reverse_sigmoid':
            get_beta = lambda t: 1 / (1 + np.exp((t / N - 0.5) * 20))
        elif self.scheduler_type == 'ones':
            get_beta = lambda t: 1.
        elif self.scheduler_type == 'zeros':
            get_beta = lambda t: 0.
        else:
            raise ValueError('Not implemented!')

        if len(data) == 0:
            raise ValueError("No examples provided")

        collate_fn_p = partial(collate_fn, pad=self.pad)

        def collate_with_filter(batch):
            batch = filter(
                lambda lst: all([len(x) <= self.max_seq_len for x in lst]),
                batch)
            return collate_fn_p(batch)

        sample_batch_size = self.batch_size * self.sample_factor
        loader = DataLoader(
            dataset=data,  # type: ignore
            shuffle=self.shuffle,
            batch_size=sample_batch_size,
            collate_fn=collate_with_filter,
            num_workers=self.n_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last)

        for epoch in range(n_epochs):
            for samp_count, batch in enumerate(loader):

                batch = [x.clone() for x in batch]
                src, tgt_context, tgt_words = batch
                max_seq_len = self.translator.max_seq_len
                pad_idx = self.translator.tgt_pad_idx
                tgt_context = pad_to_len(tgt_context,
                                         max_seq_len,
                                         pad_idx,
                                         dim=1)
                tgt_words = pad_to_len(tgt_words, max_seq_len, pad_idx, dim=1)

                beta = get_beta(epoch * samp_per_epoch + samp_count)
                dist = Categorical(torch.tensor([beta, 1 - beta]))
                samp_mask = (dist.sample((src.size(0), )) == 1)

                if torch.sum(samp_mask).item() > 0:
                    samp_src = src[samp_mask].to(self.device)
                    self.translator.model.eval()

                    with torch.no_grad():
                        samp_tgt_context = self.translator(samp_src)
                        samp_tgt_logits = self.teacher(samp_src,
                                                       samp_tgt_context)
                        _, samp_tgt_words = samp_tgt_logits.max(dim=-1)

                    # Pad words correctly
                    eos_idx = self.translator.tgt_eos_idx
                    pad_idx = self.translator.tgt_pad_idx
                    eos_mask = (samp_tgt_context == eos_idx)
                    eos_mask |= (samp_tgt_context == pad_idx)
                    samp_tgt_words[eos_mask] = pad_idx

                    samp_src = samp_src.cpu()
                    samp_tgt_context = samp_tgt_context.cpu()
                    samp_tgt_words = samp_tgt_words.cpu()

                    tgt_context[samp_mask] = samp_tgt_context
                    tgt_words[samp_mask] = samp_tgt_words

                    max_len = (tgt_words != pad_idx).sum(dim=1).max().item()
                    tgt_context = tgt_context[:, :max_len]
                    tgt_words = tgt_words[:, :max_len]

                    self.translator.model.train()

                B = self.batch_size
                for i in range(self.sample_factor):
                    step = self.sample_factor * (epoch * samp_per_epoch +
                                                 samp_count) + i
                    log('Training/Beta', beta, step)
                    src_slice = src[i * B:(i + 1) * B]
                    tgt_context_slice = tgt_context[i * B:(i + 1) * B]
                    tgt_words_slice = tgt_words[i * B:(i + 1) * B]
                    yield src_slice, tgt_context_slice, tgt_words_slice
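
The sampler above anneals beta, the probability of keeping the reference target, over N = n_epochs * samp_per_epoch sampling steps. Writing the schedules out as plain functions makes their shapes easy to inspect; the constants 200, 20, and 0.5 come directly from the code, while N and the probe points below are illustrative:

import numpy as np

def beta_linear(t: float, N: float) -> float:
    return 1. - t / N

def beta_exponential(t: float, N: float) -> float:
    return 200 ** (-t / N)

def beta_reverse_sigmoid(t: float, N: float) -> float:
    return 1 / (1 + np.exp((t / N - 0.5) * 20))

N = 1000
for t in (0, 250, 500, 750, 1000):
    print(t,
          round(beta_linear(t, N), 3),
          round(beta_exponential(t, N), 3),
          round(beta_reverse_sigmoid(t, N), 3))
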
Example #13
File: train.py  Project: taoleicn/flop
    def _eval_step(self) -> None:
        super()._eval_step()
        log_masks(self.model, self.hard_concrete_masks, self._step)
        num_parameters = get_num_params(self.hard_concrete_modules, train=False)
        num_non_prunable = self.init_num_params - self.max_prunable
        total_num_params = int(num_parameters) + num_non_prunable
        relative_sparsity = 1. - (num_parameters / self.max_prunable)
        log("Num_Params", int(num_parameters), self._step)
        log("Relative_sparsity", float(relative_sparsity), self._step)
        log("True_sparsity", 1. - total_num_params / self.init_num_params, self._step)
        log("Total_num_params", total_num_params, self._step)
        log('LambdaLR', self.optimizer_lambdas.param_groups[0]['lr'], self._step)  # type: ignore
        log('AlphaLR', self.optimizer_alphas.param_groups[0]['lr'], self._step)  # type: ignore
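
Example #13 logs two different sparsity numbers: Relative_sparsity counts only the parameters covered by the HardConcrete masks (max_prunable), while True_sparsity divides the surviving total by every parameter in the model. A small numeric sketch of that bookkeeping; the parameter counts are made up:

init_num_params = 1_000_000     # all parameters at initialization
max_prunable = 600_000          # parameters covered by HardConcrete masks
num_parameters = 150_000        # prunable parameters still alive

num_non_prunable = init_num_params - max_prunable
total_num_params = num_parameters + num_non_prunable

relative_sparsity = 1. - num_parameters / max_prunable      # 0.75
true_sparsity = 1. - total_num_params / init_num_params     # 0.45
print(relative_sparsity, true_sparsity)
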
Example #14
File: train.py  Project: taoleicn/flop
    def __init__(self,
                 dataset: Dataset,
                 train_sampler: Sampler,
                 val_sampler: Sampler,
                 model: Module,
                 loss_fn: Metric,
                 metric_fn: Metric,
                 optimizer,
                 scheduler=None,
                 device: Optional[str] = None,
                 max_steps: int = 10,
                 epoch_per_step: float = 1.0,
                 iter_per_step: Optional[int] = None,
                 batches_per_iter: int = 1,
                 lower_is_better: bool = False,
                 max_grad_norm: Optional[float] = None,
                 max_grad_abs_val: Optional[float] = None,
                 extra_validation_metrics: Optional[List[Metric]] = None,
                 lr_warmup: int = 100,
                 model_dim: int = 512,
                 iter_before_pruning: int = 0,
                 init_mean: float = 0.5,
                 init_std: float = 0.01,
                 alphas_lr: float = 0.001,
                 lambdas_lr: float = 1.0,
                 target_sparsity: float = 0.8,
                 target_sparsity_warmup: int = 80000,
                 weight_decay: float = 0) -> None:
        """Initialize the Trainer.

        Parameters
        ----------
        dataset: Dataset
            The dataset containing the first N columns of data for the
            student model, and the last N columns for the target.
        train_sampler : Sampler
            The sampler to use over training examples
        val_sampler : Sampler
            The sampler to use over validation examples
        model : Module
            The model to train
        optimizer : torch.optim.Optimizer
            The optimizer to use
        scheduler : torch.optim.lr_scheduler._LRScheduler, optional
            An optional learning rate scheduler
        device: str, optional
            The device to use in the computation. Only used by compile.
        max_steps : int, optional
            The maximum number of training steps to run
        epoch_per_step : float, optional
            Fraction of an epoch to perform in a single training step
            (i.e before a checkpoint.) Defaults to 1.
            Overridden by `iter_per_step`, if given.
        iter_per_step : int, optional
            Number of iterations to perform in a single training step.
            Overrides `epoch_per_step` if given.
        batches_per_iter : int, optional
            Number of batches to pass through the model before
            calling optimizer.step. Requires the sampler to have
            drop_last set to True. (default set to 1 so optimizer.step
            is called after every batch)
        lower_is_better : bool, optional
            If true, the lowest dev metric is considered best,
            otherwise the highest. Defaults to False.
        max_grad_norm : float, optional
            Maximum Euclidean norm of gradient after clipping.
        max_grad_abs_val: float, optional
            Maximum absolute value of all gradient vector components
            after clipping.
        extra_validation_metrics: Optional[List[Metric]]
            A list with extra metrics to show in each step
            but which don't guide the training procedures
            (i.e model selection through early stopping)

        """
        super().__init__(dataset,
                         train_sampler,  # type: ignore
                         val_sampler,
                         model,
                         loss_fn,
                         metric_fn,
                         optimizer,
                         scheduler,
                         device,
                         max_steps,
                         epoch_per_step,
                         iter_per_step,
                         batches_per_iter,
                         lower_is_better,
                         max_grad_norm,
                         max_grad_abs_val,
                         extra_validation_metrics)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.lambda_1 = nn.Parameter(torch.tensor(0.).to(device))  # type: ignore
        self.lambda_2 = nn.Parameter(torch.tensor(0.).to(device))  # type: ignore

        self.optimizer_lambdas = Adam([self.lambda_1, self.lambda_2], weight_decay=0)
        self.optimizer_lambdas.param_groups[0]['lr'] = -lambdas_lr  # type: ignore
        self.lambdas_scheduler = NoamScheduler(self.optimizer_lambdas,
                                               warmup=lr_warmup,
                                               d_model=model_dim)

        self.iter_before_pruning = iter_before_pruning
        self.init_num_params = sum([len(p.view(-1)) for p in self.model.parameters()])

        make_hard_concrete(self.model, in_place=True, init_mean=init_mean, init_std=init_std)

        self.hard_concrete_modules = get_hardconcrete_proj_linear_modules(self.model)
        self.hard_concrete_masks = get_hardconcrete_modules(self.model)
        self.max_prunable = get_num_prunable_params(self.hard_concrete_modules)
        self.target_sparsity = max(min(target_sparsity, 1.0), 0.0)
        self.target_sparsity_warmup = target_sparsity_warmup

        self.model.to(device)

        model_params = (p for n, p in self.model.named_parameters() if 'log_alpha' not in n)
        alpha_params = (p for n, p in self.model.named_parameters() if 'log_alpha' in n)

        self.optimizer_alphas = Adam(alpha_params,
                                     lr=alphas_lr)  # type: ignore
        self.alphas_scheduler = NoamScheduler(self.optimizer_alphas,
                                              warmup=lr_warmup,
                                              d_model=model_dim)

        self.optimizer = Adam(model_params,
                              lr=self.optimizer.param_groups[0]['lr'],  # type: ignore
                              weight_decay=weight_decay)
        self.lr_scheduler = NoamScheduler(self.optimizer,
                                          warmup=lr_warmup,
                                          d_model=model_dim)

        log('Total_params', int(self.init_num_params), 0)
        log('Max_prunable', int(self.max_prunable), 0)
Example #15
File: train.py  Project: taoleicn/flop
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()

        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

        with torch.enable_grad():
            for i in range(self.iter_per_step):
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                self.optimizer_alphas.zero_grad()
                self.optimizer_lambdas.zero_grad()

                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):
                    # Get next batch
                    batch = next(self._train_iterator)
                    batch = self._batch_to_device(batch)

                    # Compute loss
                    loss = self._compute_loss(batch) / self.batches_per_iter
                    accumulated_loss += loss.item()
                    loss.backward()

                # Log loss
                global_step = (self.iter_per_step * self._step) + i

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(), self.max_grad_abs_val)

                log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
                log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step)
                log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step)

                if global_step >= self.iter_before_pruning:

                    pruning_step = global_step - self.iter_before_pruning

                    num_parameters = get_num_params(self.hard_concrete_modules, train=True)
                    expected_sparsity = 1. - (num_parameters / self.max_prunable)

                    if self.target_sparsity_warmup > 0:
                        factor = min(1.0, pruning_step / self.target_sparsity_warmup)
                        target_sparsity = self.target_sparsity * factor
                    else:
                        target_sparsity = self.target_sparsity

                    lagrangian_loss = self.lambda_1 * (target_sparsity - expected_sparsity)
                    lagrangian_loss += self.lambda_2 * (target_sparsity - expected_sparsity) ** 2
                    lagrangian_loss.backward()
                    log("Expected_sparsity", float(expected_sparsity), global_step)
                    log("Lagrangian_loss", lagrangian_loss.item(), global_step)
                    log("Target_sparsity", target_sparsity, global_step)
                    log("lambda_1", self.lambda_1.item(), global_step)
                    log("lambda_2", self.lambda_2.item(), global_step)

                    self.optimizer_lambdas.step()
                    self.lambdas_scheduler.step(pruning_step)

                    self.optimizer_alphas.step()
                    self.alphas_scheduler.step(pruning_step)

                # Optimize
                self.optimizer.step()
                self.lr_scheduler.step(global_step)

            # Zero the gradients when exiting a train step
            self.optimizer.zero_grad()
            self.optimizer_lambdas.zero_grad()
            self.optimizer_alphas.zero_grad()
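
Examples #14 and #15 together implement a constrained-sparsity objective: the lambdas are given a negated learning rate on optimizer_lambdas in __init__, so optimizer_lambdas.step() performs gradient ascent on the Lagrangian penalty while the model parameters descend. A numeric sketch of the penalty and of the direction the ascent pushes the multipliers; the values are made up:

import torch

lambda_1 = torch.tensor(0.1, requires_grad=True)
lambda_2 = torch.tensor(0.1, requires_grad=True)

target_sparsity = torch.tensor(0.8)
expected_sparsity = torch.tensor(0.3)     # model is not yet sparse enough

gap = target_sparsity - expected_sparsity
lagrangian_loss = lambda_1 * gap + lambda_2 * gap ** 2
lagrangian_loss.backward()

print(lagrangian_loss.item())   # 0.1 * 0.5 + 0.1 * 0.25 = 0.075
print(lambda_1.grad.item())     # 0.5  -> ascent raises lambda_1 while the gap persists
print(lambda_2.grad.item())     # 0.25 -> ascent raises lambda_2 as well
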
Example #16
    def run(self, block_name: str = None) -> bool:
        self.model.to(self.device)
        self.model.eval()
        if self.teacher is not None:
            self.teacher.to(self.device)
            self.teacher.eval()

        with torch.no_grad():
            preds, targets = [], []
            if self.save_preds:
                sources = []
            if self.teacher is not None:
                teacher_preds = []

            for batch in self._eval_iterator:
                batch = [t.to(self.device) for t in batch]
                pred = self.model(batch[0], gen_style=self.gen_style)
                if self.save_preds:
                    source = batch[0]
                    if source.size(1) < 50:
                        extra = torch.zeros(
                            (source.size(0),
                             50 - source.size(1))).long().to(self.device)
                        source = torch.cat([source, extra], dim=1)
                    sources.append(source)
                    if self.teacher is not None:
                        student_context = pred[:, :-1].to(self.device)
                        _, teacher_pred = self.teacher(
                            batch[0], student_context).max(dim=2)
                        teacher_pred[student_context ==
                                     self.model.eos_idx] = self.model.pad_idx
                        teacher_pred[student_context ==
                                     self.model.pad_idx] = self.model.pad_idx
                        teacher_preds.append(teacher_pred.cpu())
                target = batch[1]
                if target.size(1) < 50:
                    extra = torch.zeros(
                        (target.size(0),
                         50 - target.size(1))).long().to(self.device)
                    target = torch.cat([target, extra], dim=1)
                preds.append(pred.cpu())
                targets.append(target.cpu())

            preds = torch.cat(preds, dim=0)  # type: ignore
            if self.save_preds:
                sources = torch.cat(sources, dim=0)
                if self.teacher is not None:
                    teacher_preds = torch.cat(teacher_preds, dim=0)
                    self.decode_data = (sources, preds[:, :-1], teacher_preds)
                else:
                    self.decode_data = (sources, preds[:, :-1], preds[:, 1:])
            targets = torch.cat(targets, dim=0)  # type: ignore
            if self.save_targets:
                self.targets = targets
            self.eval_metric = self.metric_fn(preds, targets).item()

            tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

            log(
                f'{tb_prefix}Eval {self.metric_fn}',  # type: ignore
                self.eval_metric,
                global_step=0)  # type: ignore

        return False