Example #1
    def save_model(self, epoch=None, save_name=None):
        if save_name is None:
            save_name = 'model.epoch.%d.pt' % epoch

        if self.mixed_precision:
            import apex.amp as amp
            amp_state_dict = amp.state_dict()
        else:
            amp_state_dict = None

        checkpoint = {
            'epoch': epoch,
            'params': self.params,
            'model': (self.model.module.state_dict()
                      if self.ngpu > 1 else self.model.state_dict()),
            'optimizer': self.optimizer.state_dict(),
            'amp': amp_state_dict
        }

        torch.save(checkpoint, os.path.join(self.expdir, save_name))
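A checkpoint written by Example #1 can be restored with the matching Apex calls. The sketch below is illustrative only: the model, optimizer, and checkpoint path are placeholders, and it assumes the same opt_level ('O1' here) that was active when the checkpoint was saved.

import os
import torch
from apex import amp

# Placeholder model/optimizer standing in for self.model / self.optimizer above.
model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# amp.initialize() must run before amp.load_state_dict(), with the same
# opt_level that was used when the checkpoint was written.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

checkpoint = torch.load(os.path.join('expdir', 'model.epoch.0.pt'))  # placeholder path
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
if checkpoint['amp'] is not None:
    amp.load_state_dict(checkpoint['amp'])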
Example #2
    def test_loss_scale_decrease(self):
        num_losses = 3
        nb_decrease_loss_scales = [0, 1, 2]
        for opt_level in self.test_opt_levels:
            #print('#' * 75 + f'\n opt_level {opt_level}\n')
            # Create new tmp copy for this run
            nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales)

            model = MyModel().to('cuda')

            optimizer = optim.SGD(model.parameters(),
                                  lr=self.initial_lr)

            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, num_losses=num_losses,
                verbosity=0)

            if amp._amp_state.opt_properties.loss_scale != 'dynamic':
                #print('Static loss scale set. Skipping opt_level.')
                continue

            # force to skip some updates to decrease the loss_scale
            initial_loss_scales = []
            for idx in range(num_losses):
                initial_loss_scales.append(
                    amp._amp_state.loss_scalers[idx].loss_scale())

            for _ in range(len(nb_decrease_loss_scales)):
                x = torch.randn(16, 3, 24, 24, device='cuda')
                for idx in range(num_losses):
                    while nb_decrease_loss_scales_tmp[idx] > 0:
                        optimizer.zero_grad()
                        output = model(x * 2**17)
                        loss = output.mean()

                        with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                            scaled_loss.backward(retain_graph=True)
                        optimizer.step()
                        nb_decrease_loss_scales_tmp[idx] -= 1

            # Check loss scales afterwards
            updated_loss_scales = []
            for idx in range(num_losses):
                updated_loss_scales.append(
                    amp._amp_state.loss_scalers[idx].loss_scale())
            for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
                                                  updated_loss_scales,
                                                  initial_loss_scales):
                self.assertEqual(update_ls, init_ls / 2**factor)

            # Check state dict
            amp_state_dict = amp.state_dict()
            for scaler_idx, factor, init_ls in zip(amp_state_dict,
                                                   nb_decrease_loss_scales,
                                                   initial_loss_scales):
                scaler = amp_state_dict[scaler_idx]
                self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
                unskipped_target = 0
                self.assertEqual(scaler['unskipped'], unskipped_target)
Example #3
    def dump_checkpoint(self, weights_only: bool = False) -> dict:
        """Creating model checkpoint.

        Args:
            weights_only: saving model weights only

        Return:
             structured dictionary
        """
        checkpoint = {
            'epoch': self.trainer.current_epoch + 1,
            'global_step': self.trainer.global_step + 1,
            'pytorch-lightning_version': pytorch_lightning.__version__,
        }

        if not weights_only:

            # save callbacks
            callback_states = self.trainer.on_save_checkpoint()
            checkpoint['callbacks'] = callback_states

            # save optimizers
            optimizer_states = []
            for i, optimizer in enumerate(self.trainer.optimizers):
                optimizer_states.append(optimizer.state_dict())
            checkpoint['optimizer_states'] = optimizer_states

            # save lr schedulers
            lr_schedulers = []
            for scheduler in self.trainer.lr_schedulers:
                lr_schedulers.append(scheduler['scheduler'].state_dict())
            checkpoint['lr_schedulers'] = lr_schedulers

            # save native amp scaling
            if self.trainer.amp_backend == AMPType.NATIVE and not self.trainer.use_tpu and self.trainer.scaler is not None:
                checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
            elif self.trainer.amp_backend == AMPType.APEX:
                checkpoint['amp_scaling_state'] = amp.state_dict()

        # add the module_arguments and state_dict from the model
        model = self.trainer.get_model()

        checkpoint['state_dict'] = model.state_dict()

        if model.hparams:
            if hasattr(model, '_hparams_name'):
                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
            # add arguments to the checkpoint
            if OMEGACONF_AVAILABLE:
                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
                if isinstance(model.hparams, Container):
                    checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
            else:
                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

        # give the model a chance to add a few things
        model.on_save_checkpoint(checkpoint)

        return checkpoint
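Example #3 covers both mixed-precision backends: with native AMP it stores the torch.cuda.amp.GradScaler state under 'native_amp_scaling_state', and with Apex it stores amp.state_dict(). Outside Lightning, the native-AMP equivalent looks roughly like the sketch below; the model, optimizer, and file name are placeholders, not part of the Lightning API.

import torch

model = torch.nn.Linear(4, 2).cuda()                      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # placeholder optimizer
scaler = torch.cuda.amp.GradScaler()

# ... a training loop would call scaler.scale(loss).backward(),
#     scaler.step(optimizer) and scaler.update() here ...

checkpoint = {
    'state_dict': model.state_dict(),
    'optimizer_states': [optimizer.state_dict()],
    'native_amp_scaling_state': scaler.state_dict(),  # same key as in Example #3
}
torch.save(checkpoint, 'checkpoint.pt')

# Restoring: recreate model/optimizer/scaler, then load their states.
restored = torch.load('checkpoint.pt')
model.load_state_dict(restored['state_dict'])
optimizer.load_state_dict(restored['optimizer_states'][0])
scaler.load_state_dict(restored['native_amp_scaling_state'])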
Example #4
def save_checkpoint(net, optimizer, lr_scheduler, is_mixed_precision, filename='temp.pt'):
    checkpoint = {
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler is not None else None,
        'amp': amp.state_dict() if is_mixed_precision else None
    }
    torch.save(checkpoint, filename)
Example #5
 def _save(self, chkpt_name):
     # Save checkpoint
     checkpoint = {
         'model': self.net.state_dict(),
         'optimizer': self.optimizer.state_dict(),
         'amp': amp.state_dict()
     }
     torch.save(checkpoint, chkpt_name)
Example #6
def save_checkpoints(model, path, optimizer=None, niters=0):
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "amp": amp.state_dict(),
            "niters": niters,
        }, path)
Example #7
 def __call__(self, engine):
     checkpoint = {
         'model': engine.model.state_dict(),
         'optimizer': engine.optimizer.state_dict(),
     }
     if engine.FP16:
         checkpoint['amp'] = amp.state_dict()
     torch.save(checkpoint, self.path)
Example #8
    def save(self, num):
        save_data = {'GAN': self.GAN.state_dict()}

        if self.GAN.fp16:
            save_data['amp'] = amp.state_dict()

        torch.save(save_data, self.model_name(num))
        self.write_config()
Example #9
    def epoch_event_function(engine):
        if args.test_during_training:
            evaluator_for_train.run(
                train_loader
            )  # It is better to re-make a train_loader_for_evaluation so as not to disturb the random number generator.
            performance = evaluator_for_train.state.metrics['IQA_performance']
            writer_add_scalar(writer, 'train', args.dataset, performance,
                              engine.state.epoch)
            k = performance['k']
            b = performance['b']
        else:
            k = [1, 1, 1]
            b = [0, 0, 0]

        evaluator = create_supervised_evaluator(model,
                                                metrics={
                                                    'IQA_performance':
                                                    IQAPerformance(
                                                        status='test',
                                                        k=k,
                                                        b=b,
                                                        mapping=mapping)
                                                },
                                                device=device)
        evaluator.run(val_loader)
        performance = evaluator.state.metrics['IQA_performance']
        writer_add_scalar(writer, 'val', args.dataset, performance,
                          engine.state.epoch)
        val_criterion = abs(
            performance[args.val_criterion]
        )  # when alpha=[0,1], loss_type='linearity', test_during_training=False, SROCC/PLCC can be negative during training.
        if args.test_during_training:
            evaluator.run(test_loader)
            performance = evaluator.state.metrics['IQA_performance']
            writer_add_scalar(writer, 'test', args.dataset, performance,
                              engine.state.epoch)

        global best_val_criterion, best_epoch
        if val_criterion > best_val_criterion:  # If RMSE is used, then change ">" to "<".
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict(),
                'k': k,
                'b': b
            }
            torch.save(checkpoint, args.trained_model_file)
            best_val_criterion = val_criterion
            best_epoch = engine.state.epoch
            print(
                'Save current best model @best_val_criterion ({}): {:.3f} @epoch: {}'
                .format(args.val_criterion, best_val_criterion, best_epoch))
        else:
            print(
                'Model is not updated @val_criterion ({}): {:.3f} @epoch: {}'.
                format(args.val_criterion, val_criterion, engine.state.epoch))

        scheduler.step(engine.state.epoch)
Example #10
    def save(self, path):
        checkpoint = {
            'model': self.net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'amp': amp.state_dict()
        }

        torch.save(checkpoint, path)
        print(f'Saved model to {path}')
Example #11
 def checkpoint(self):
     best_checkpoint = {
         "decoder": deepcopy(self.decoder.state_dict()),
         "optimizer": deepcopy(self.optimizer.state_dict()),
         "schedular": deepcopy(self.schedular.state_dict()),
         "step": self.step,
     }
     best_checkpoint['amp'] = deepcopy(amp.state_dict())
     return best_checkpoint
Example #12
 def save(self):
     checkpoint = {'global_step': self.global_step,
                   'model_state_dict': _to_cpu(self.model.state_dict()),
                   'optim_state_dict': _to_cpu(self.optimizer.state_dict())}
     if self.amp:
         checkpoint['amp_state_dict'] = amp.state_dict()
     if exists(self.save_path):
         os.rename(self.save_path, self.backup_path)
     torch.save(checkpoint, self.save_path)
Example #13
 def save(self):
     checkpoint_to_save = {'global_step': self.global_step}
     for k in self.ckpt_dict:
         checkpoint_to_save[k] = _to_cpu(self.ckpt_dict[k].state_dict())
     if self.amp:
         checkpoint_to_save['amp_state_dict'] = amp.state_dict()
     if exists(self.save_path):
         os.rename(self.save_path, self.backup_path)
     torch.save(checkpoint_to_save, self.save_path)
Example #14
def update_train_state(args, model, optimizer, train_state):
    # Save one model at least
    if train_state['epoch_index'] == 0:
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        if args.fp16: checkpoint['amp'] = amp.state_dict()
        torch.save(checkpoint, train_state['model_filename'])

        train_state['stop_early'] = False

    # Save model if performance improved
    elif train_state['epoch_index'] >= 1:
        loss_tm1, loss_t = train_state['val_running_loss'][-2:]

        # If loss worsened
        if loss_t >= loss_tm1:
            # Update step
            train_state['early_stopping_step'] += 1
        # Loss decreased
        else:
            # Save the best model
            if loss_t < train_state['early_stopping_best_val']:

                checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'cursor_train': train_state['cursor_train'],
                    'cursor_val': train_state['cursor_val'],
                }
                if args.fp16: checkpoint['amp'] = amp.state_dict()
                torch.save(checkpoint, train_state['model_filename'])

                train_state['early_stopping_best_val'] = loss_t

            # Reset early stopping step
            train_state['early_stopping_step'] = 0

        # Stop early ?
        train_state['stop_early'] = \
            train_state['early_stopping_step'] >= args.early_stopping_criteria

    return train_state
Example #15
def save(args, model, optimizer):
    if args.local_rank == -1 or (args.is_distributed and args.global_rank == 0):
        torch.save(model.state_dict(),os.path.join(args.output_dir,'model_checkpoint.pt'))
        if args.fp16:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict()
            }
            torch.save(checkpoint, os.path.join(args.output_dir,'amp_checkpoint.pt'))
Example #16
 def save(self, path):
     self.model.eval()
     torch.save({
         'model_state_dict': self.model.state_dict(),
         'optimizer_state_dict': self.optimizer.state_dict(),
         'scheduler_state_dict': self.scheduler.state_dict(),
         'best_summary_loss': self.best_summary_loss,
         'epoch': self.epoch,
         'amp': amp.state_dict() # apex
     }, path)
Example #17
def save_checkpoint(file_path, epoch, model, optimizer, use_apex=False):
    """Save current training state"""
    ckpt = {
        "epoch": epoch,
        "model": model.create_ckpt(),
        "optimizer": optimizer.state_dict()
    }
    if use_apex:
        ckpt["apex"] = amp.state_dict()
    torch.save(ckpt, file_path)
Example #18
 def save_checkpoint(self, epoch):
     if epoch % self.save_epoch == 0 and self.rank == 0:
         state = {'config': self.config,
                  'epoch': epoch,
                  'steps': self.steps,
                  'model': self.model.state_dict(),
                  'optimizer': self.optimizer.state_dict(),
                  'amp': amp.state_dict()
                 }
         torch.save(state, self.ckpt_path.format(epoch))
Example #19
def save_checkpoint(model, opt, iteration, best_metric, out_path, fp16):
    checkpoint = {
        "iteration": iteration,
        "best_metric": best_metric,
        "state_dict": model.state_dict(),
        "opt": opt.state_dict()
    }
    if fp16:
        checkpoint.update({"amp": amp.state_dict()})
    torch.save(checkpoint, out_path)
Example #20
 def get_state_dict(self):
     print('\n Getting state dict. \n')
     save_data = {
         'GAN': self.GAN.state_dict(),
         'G_opt': self.GAN.G_opt.state_dict(),
         'D_opt': self.GAN.D_opt.state_dict()
     }
     if self.GAN.fp16:
         save_data['amp'] = amp.state_dict()
     return save_data
Example #21
 def save_checkpoint(self, filename, extra_state):
     """Save all training state in a checkpoint file."""
     if self.args.amp:
         extra_state['amp_state_dict'] = amp.state_dict()
         extra_state['amp_master_params'] = list(amp.master_params(self.optimizer.optimizer))
     if distributed_utils.is_master(self.args):  # only save one checkpoint
         utils.save_state(
             filename, self.args, self.get_model(), self.criterion, self.optimizer,
             self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
         )
Example #22
def save_model(model: TEDD1104,
               save_dir: str,
               fp16,
               amp_opt_level: str = None) -> None:
    """
    Save model to a directory. This function stores two files, the hyperparameters and the weights.

    Input:
     - model: TEDD1104 model to save
     - save_dir: directory where the model will be saved, if it doesn't exists we create it
     - amp: If the model uses FP16, Nvidia Apex AMP
     - amp_opt_level: If the model uses FP16, the AMP opt_level

    Output:

    """

    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    dict_hyperparams: dict = {
        "resnet": model.resnet,
        "pretrained_resnet": model.pretrained_resnet,
        "sequence_size": model.sequence_size,
        "embedded_size": model.embedded_size,
        "hidden_size": model.hidden_size,
        "num_layers_lstm": model.num_layers_lstm,
        "bidirectional_lstm": model.bidirectional_lstm,
        "layers_out": model.layers_out,
        "dropout_cnn": model.dropout_cnn,
        "dropout_cnn_out": model.dropout_cnn_out,
        "dropout_lstm": model.dropout_lstm,
        "dropout_lstm_out": model.dropout_lstm_out,
        "fp16": fp16,
        "amp_opt_level": amp_opt_level,
    }

    model_weights: dict = {
        "model": model.state_dict(),
        "amp": None if not fp16 else amp.state_dict(),
    }

    with open(os.path.join(save_dir, "model_hyperparameters.json"),
              "w+") as file:
        json.dump(dict_hyperparams, file)

    torch.save(obj=model_weights, f=os.path.join(save_dir, "model.bin"))
Example #23
    def __init__(
        self,
        evaluator,
        env_gen,
        optim=None,
        memory_queue=None,
        iterations=100,
        temperature_cutoff=5,
        batch_size=64,
        memory_size=200000,
        min_memory=20000,
        update_nn=True,
        starting_state_dict=None,
    ):
        self.iterations = iterations
        self.evaluator = evaluator.to(device)
        self.env_gen = env_gen
        self.optim = optim
        self.env = env_gen()
        self.root_node = None
        self.reset()
        self.update_nn = update_nn
        self.starting_state_dict = starting_state_dict

        self.memory_queue = memory_queue
        self.temp_memory = []
        self.memory = Memory(memory_size)
        self.min_memory = min_memory
        self.temperature_cutoff = temperature_cutoff
        self.actions = self.env.action_space.n

        self.evaluating = False

        self.batch_size = batch_size

        if APEX_AVAILABLE:
            opt_level = "O1"

            if self.optim:
                self.evaluator, self.optim = amp.initialize(
                    evaluator, optim, opt_level=opt_level)
                print("updating optimizer and evaluator")
            else:
                self.evaluator = amp.initialize(evaluator, opt_level=opt_level)
                print(" updated evaluator")
            self.amp_state_dict = amp.state_dict()
            print(vars(amp._amp_state))
        elif APEX_AVAILABLE:
            opt_level = "O1"
            print(vars(amp._amp_state))

        if self.starting_state_dict:
            print("laoding [sic] state dict in mcts")
            self.load_state_dict(self.starting_state_dict)
Example #24
 def save(self, epoch, model, optimizer, losses, train_step):
   model.cpu()
   torch.save({
     'epoch': epoch,  # current training epoch
     'model_state_dict': model.state_dict(),  # save the model
     'optimizer_state_dict': optimizer.state_dict(),  # save the optimizer
     'losses': losses,  # save the losses
     'train_step': train_step,  # training progress so far
     'amp': amp.state_dict()
   }, f'{self.checkpoint_path}/{self.model_name}.pth')
   model.cuda()
Example #25
def save_checkpoint(decoder, optimizer, amp, scheduler, step, checkpoint_dir):
    checkpoint_state = {
        "vocoder": decoder.state_dict(),
        "optimizer": optimizer.state_dict(),
        "amp": amp.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step}
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    checkpoint_path = checkpoint_dir / "model.ckpt-{}.pt".format(step)
    torch.save(checkpoint_state, checkpoint_path)
    print("Saved checkpoint: {}".format(checkpoint_path.stem))
Example #26
 def get_state(self, net, opt, step):
     try:
         net_dict = net.module.state_dict()
     except AttributeError:
         net_dict = net.state_dict()
     state = dict(step=step, net=net_dict, opt=opt.state_dict())
     try:
         state['amp'] = amp.state_dict()
     except:
         pass
     return to_torch(state, device='cpu')
Example #27
    def validation_end(self, outputs):
        # OPTIONAL
        self.logger.experiment.add_text('test', 'This is test', 0)

        avg_wer = np.mean([x['wer'] for x in outputs])
        ppl = np.mean([x['val_loss'] for x in outputs])
        self.logger.experiment.add_scalar('val/WER', avg_wer, self.steps)
        self.logger.experiment.add_scalar('val/perplexity', ppl, self.steps)

        hypothesis, ground_truth = '', ''
        for idx in range(min(5, len(outputs))):
            hypothesis += outputs[idx]['hypothesis'] + '\n\n'
            ground_truth += outputs[idx]['ground_truth'] + '\n\n'

        self.logger.experiment.add_text('generated', hypothesis, self.steps)
        self.logger.experiment.add_text('ground_truth', ground_truth,
                                        self.steps)
        if self.latest_alignment is not None:
            alignment = self.latest_alignment
            idx = random.randint(0, alignment.size(0) - 1)
            alignment = torch.softmax(alignment[idx], dim=-1)
            alignment[:, :, 0] = 0  # ignore blank token
            alignment = alignment.mean(dim=-1)

            self.logger.experiment.add_image("alignment",
                                             plot_alignment_to_numpy(
                                                 alignment.data.numpy().T),
                                             self.steps,
                                             dataformats='HWC')
        self.logger.experiment.flush()

        if self.best_wer > avg_wer:
            print('best checkpoint found!')
            checkpoint = {
                'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'epoch': self.epoch
            }
            if self.args.apex:
                checkpoint['amp'] = amp.state_dict()
            torch.save(
                checkpoint,
                os.path.join(self.args.log_path,
                             str(self.epoch) + 'amp_checkpoint.pt'))
            self.best_wer = avg_wer

        self.plateau_scheduler.step(avg_wer)
        self.epoch += 1

        return {
            'val/WER': torch.tensor(avg_wer),
            'wer': torch.tensor(avg_wer),
            'val/perplexity': torch.tensor(ppl)
        }
Example #28
    def _save_checkpoint(self) -> str:
        """
        Save the model's current parameters and the training state to a
        checkpoint.

        The training state contains the total number of training steps,
        the total number of training tokens,
        the best checkpoint score and iteration so far,
        and optimizer and scheduler states.

        """
        model_path = "{}/{}.ckpt".format(self.model_dir, self.stats.steps)
        model_state_dict = self.model.module.state_dict() \
            if isinstance(self.model, torch.nn.DataParallel) \
            else self.model.state_dict()
        state = {
            "steps": self.stats.steps,
            "total_tokens": self.stats.total_tokens,
            "best_ckpt_score": self.stats.best_ckpt_score,
            "best_ckpt_iteration": self.stats.best_ckpt_iter,
            "model_state": model_state_dict,
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": (self.scheduler.state_dict()
                                if self.scheduler is not None else None),
            "amp_state": amp.state_dict() if self.fp16 else None
        }
        torch.save(state, model_path)
        if self.ckpt_queue.full():
            to_delete = self.ckpt_queue.get()  # delete oldest ckpt
            try:
                os.remove(to_delete)
            except FileNotFoundError:
                logger.warning(
                    "Wanted to delete old checkpoint %s but "
                    "file does not exist.", to_delete)

        self.ckpt_queue.put(model_path)

        best_path = "{}/best.ckpt".format(self.model_dir)
        try:
            # create/modify symbolic link for best checkpoint
            symlink_update("{}.ckpt".format(self.stats.steps), best_path)
        except OSError:
            # overwrite best.ckpt
            torch.save(state, best_path)
        return best_path
Example #29
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        if self.master:
            save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
            save_path = os.path.join(self.save_dir, save_filename)

            if self.opt.distributed:
                torch.save(network.module.state_dict(), save_path)
                save_path = save_path.replace(network_label, 'amp')
                torch.save(amp.state_dict(), save_path)
            else:
                torch.save(network.state_dict(), save_path)
Example #30
def save_checkpoint(model, optimizer, epoch, config, filepath, amp=None):
    print("Saving model and optimizer state at epoch {} to {}".format(
        epoch, filepath))
    torch.save(
        {
            'epoch': epoch,
            'config': config,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'amp': amp.state_dict() if amp is not None else None
        }, filepath)