class DDPTrainer(object):
    """Main class for data parallel training.

    This class supports data parallel training, where multiple workers each
    have a full model replica and gradients are accumulated synchronously via
    torch.distributed.all_reduce.
    """

    def __init__(self, args, model):
        """Set up model, criterion, optimizer, LR scheduler and AMP/DDP wrappers.

        Args:
            args: training arguments; ``criterion``, ``amp``, ``amp_level`` and
                ``distributed_world_size`` are read here.
            model: the model to train; it is moved to CUDA.

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args

        self.model = model.cuda()
        self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

        if args.amp:
            # Keep the model handed back by amp.initialize. The returned
            # optimizer is not stored: the rest of this class keeps addressing
            # the AMP-wrapped optimizer through self.optimizer._optimizer.
            self.model, _ = amp.initialize(
                self.model,
                self.optimizer._optimizer,
                opt_level=self.args.amp_level if self.args.amp_level else 'O2',
                max_loss_scale=2**15,
                cast_model_outputs=torch.float16,
            )

        if self.args.distributed_world_size > 1:
            self.model = DDP(self.model)

        # Per-call stats buffered until the next parameter update.
        self._buffered_stats = defaultdict(list)
        self._flat_grads = None
        self._num_updates = 0
        self._num_val_iterations = 0
        self._optim_history = None
        self.throughput_meter = TimeMeter()

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file.

        When AMP is enabled, the AMP state dict and master params are added to
        ``extra_state`` (mutated in place). Only the master process writes.
        """
        if self.args.amp:
            extra_state['amp_state_dict'] = amp.state_dict()
            extra_state['amp_master_params'] = list(amp.master_params(self.optimizer.optimizer))
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            utils.save_state(
                filename, self.args, self.get_model(), self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, load_optim=True):
        """Load all training state from a checkpoint file.

        Args:
            filename: checkpoint path.
            load_optim: if True, restore the optimizer / LR-scheduler state
                (only when the saved criterion and optimizer class names match)
                and the update counter.

        Returns:
            The ``extra_state`` stored in the checkpoint (may be None).
        """
        extra_state, optim_history, last_optim_state = \
            utils.load_model_state(filename, self.get_model())

        if last_optim_state is not None:
            # Rebuild the LR scheduler around the current optimizer before
            # optionally restoring its saved state.
            self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

            if load_optim:
                self._optim_history = optim_history
                # only reload optimizer and lr_scheduler if they match
                last_optim = self._optim_history[-1]
                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                    self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                        self.optimizer.load_state_dict(last_optim_state)

                self._num_updates = last_optim['num_updates']

        if self.args.amp and extra_state is not None and 'amp_state_dict' in extra_state:
            # Force apex's lazy master-weight initialisation, then overwrite
            # the master params with the saved copies.
            # NOTE(review): this runs even when load_optim is False and assumes
            # last_optim_state is not None -- confirm intended.
            self.optimizer.optimizer._lazy_init_maybe_master_weights()
            self.optimizer.optimizer._amp_stash.lazy_init_called = True
            self.optimizer.optimizer.load_state_dict(last_optim_state)
            for param, saved_param in zip(amp.master_params(self.optimizer.optimizer), extra_state['amp_master_params']):
                param.data.copy_(saved_param.data)

            amp.load_state_dict(extra_state['amp_state_dict'])

        return extra_state

    def train_step(self, sample, update_params=True, last_step=False):
        """Do forward, backward and (optionally) a parameter update.

        Stats of every call are buffered. When ``update_params`` is True and
        ``last_step`` is False, buffered stats are gathered from all replicas,
        gradients are divided by the total sample size, an optimizer step is
        taken and throughput/loss are logged.

        Args:
            sample: batch dict, or None / empty when this worker has no data.
            update_params: take an optimizer step after this call.
            last_step: disable DDP all-reduce and skip the update.
        """
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        if isinstance(self.model, DDP):
            if last_step:
                self.model.disable_allreduce()
            else:
                self.model.enable_allreduce()

        # forward and backward pass
        sample, sample_size = self._prepare_sample(sample)
        loss, oom_fwd = self._forward(sample)

        # If this is a last batch forward pass is skipped on some workers
        # Batch with sample_size 0 is not accounted for in weighted loss
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'loss': utils.item(loss.data) if loss is not None else 0,
            'sample_size': sample_size,
        }
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)

        if not update_params or last_step:
            return

        # gather logging outputs from all replicas
        sample_sizes = self._buffered_stats['sample_sizes']
        logging_outputs = self._buffered_stats['logging_outputs']
        ooms_fwd = self._buffered_stats['ooms_fwd']
        ooms_bwd = self._buffered_stats['ooms_bwd']
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                lambda per_worker: list(chain.from_iterable(per_worker)),
                zip(*distributed_utils.all_gather_list(
                    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
                ))
            )
        # A worker records at most one OOM (forward or backward) per buffered
        # call: a forward OOM yields loss=None, so backward is skipped.
        ooms = sum(ooms_fwd) + sum(ooms_bwd)

        if ooms == self.args.distributed_world_size:
            print('| WARNING: OOM in all workers, skipping batch')
            self.zero_grad()
            # Drop the skipped batch's stats so they don't leak into the next update.
            self.clear_buffered_stats()
            return

        # aggregate stats and logging outputs
        grad_denom = sum(sample_sizes)
        if grad_denom > 0:  # avoid 0/0 -> NaN gradients when no worker had data
            for p in self.model.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad /= grad_denom

        self._opt()

        # Handle logging
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        self.throughput_meter.update(ntokens)
        info_log_data = {
            'tokens/s': self.throughput_meter.avg,
            'tokens': ntokens,
            # loss per sample_size unit, converted to base 2
            'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
        }
        debug_log_data = {
            'batch_size': sum(log.get('nsentences', 0) for log in logging_outputs),
            'lr': self.get_lr(),
            'grad_denom': grad_denom,
            'updates': 1,
        }

        DLLogger.log(step=self._num_updates, data=info_log_data, verbosity=0)
        DLLogger.log(step=self._num_updates, data=debug_log_data, verbosity=1)

        self.clear_buffered_stats()

    def _forward(self, sample, eval_mode=False):
        """Run the model and the criterion on ``sample``.

        Args:
            sample: prepared batch, or None (no forward pass is done).
            eval_mode: if True, out-of-memory errors are re-raised instead of
                being turned into a skipped batch.

        Returns:
            ``(loss, oom)``: ``loss`` is None when nothing was computed;
            ``oom`` is 1 if the batch was skipped because of OOM, else 0.
        """
        loss = None
        oom = 0
        try:
            if sample is not None:
                # calculate loss and sample size
                logits, _ = self.model(**sample['net_input'])
                target = sample['target']
                if not self.args.adaptive_softmax_cutoff:
                    probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                else:
                    # TODO: training crashes after a couple hundred iterations
                    # because of an unknown error in PyTorch's autograd
                    probs, target = self.get_model().decoder.adaptive_softmax(logits, target.view(-1))
                loss = self.criterion(probs, target)
        except RuntimeError as e:
            if not eval_mode and 'out of memory' in str(e):
                # NOTE(review): the builtin print rejects `force=True`; this
                # relies on print being replaced elsewhere -- confirm.
                print('| WARNING: ran out of memory in worker {}, skipping batch'.format(self.args.distributed_rank), force=True)
                oom = 1
                loss = None
            else:
                raise
        return loss, oom

    def _backward(self, loss):
        """Backpropagate ``loss`` (AMP-scaled if enabled); return 1 on OOM, else 0."""
        oom = 0
        if loss is not None:
            try:
                if self.args.amp:
                    with amp.scale_loss(loss, self.optimizer._optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory in worker {}, skipping batch'.format(self.args.distributed_rank), force=True)
                    oom = 1
                    self.zero_grad()
                else:
                    raise
        return oom

    def _opt(self):
        """Take an optimizer step, clear grads and advance the LR schedule."""
        # take an optimization step
        self.optimizer.step()
        self.zero_grad()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

    def valid_step(self, sample):
        """Do forward pass in evaluation mode.

        Returns:
            The summed loss over all replicas divided by the total sample
            size and by ln(2); 0.0 when no replica had any samples.

        Raises:
            RuntimeError: any error from the forward pass, including OOM.
        """
        self.model.eval()
        self._num_val_iterations += 1
        # forward pass
        sample, sample_size = self._prepare_sample(sample)
        with torch.no_grad():
            loss, oom_fwd = self._forward(sample, eval_mode=True)
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'sample_size': sample_size,
        }
        loss = loss.item() if loss is not None else 0
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            losses, _, logging_outputs = zip(*distributed_utils.all_gather_list(
                (loss, sample_size, logging_output)
            ))
        else:
            losses = [loss]
            logging_outputs = [logging_output]

        # TODO: check when ntokens != sample_size (loss is normalised by sample_size)
        weight = sum(log.get('sample_size', 0) for log in logging_outputs)
        if weight == 0:
            return 0.0
        return sum(losses) / weight / math.log(2)

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, update_params=False)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        """Zero the optimizer's gradients."""
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        """Drop all stats buffered since the last update."""
        self._buffered_stats.clear()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_throughput_meter(self):
        """Get the throughput meter"""
        return self.throughput_meter

    def get_model(self):
        """Get the model replica (unwrapped from DDP if needed)."""
        return self.model.module if isinstance(self.model, DDP) else self.model

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        """Move ``sample`` to CUDA; return ``(sample, ntokens)`` or ``(None, 0)`` if empty."""
        if sample is None or len(sample) == 0:
            return None, 0
        return utils.move_to_cuda(sample), sample['ntokens']
# Example #2
# 0
class FlowNMT(nn.Module):
    """
    NMT model with Generative Flow.

    Thin wrapper around a ``FlowNMTCore``. After ``init_distributed`` the core
    is wrapped in ``DistributedDataParallel``; ``_get_core`` returns the
    unwrapped module.
    """

    def __init__(self, core: FlowNMTCore):
        super(FlowNMT, self).__init__()
        self.core = core
        self.length_unit = self.core.prior.length_unit
        # NOTE: attribute name is misspelled ("distribured"); kept as-is
        # because code outside this class may read it.
        self.distribured_enabled = False

    def _get_core(self):
        """Return the core, unwrapping the DDP wrapper when distributed."""
        return self.core.module if self.distribured_enabled else self.core

    def sync(self):
        """Call ``sync()`` on the core's prior."""
        core = self._get_core()
        core.prior.sync()

    def init(self, src_sents, tgt_sents, src_masks, tgt_masks, init_scale=1.0):
        """Delegate ``init`` to the unwrapped core."""
        core = self._get_core()
        core.init(src_sents, tgt_sents, src_masks, tgt_masks, init_scale=init_scale)

    def init_posterior(self, src_sents, tgt_sents, src_masks, tgt_masks, init_scale=1.0):
        """Delegate ``init_posterior`` to the unwrapped core."""
        core = self._get_core()
        core.init_posterior(src_sents, tgt_sents, src_masks, tgt_masks, init_scale=init_scale)

    def init_prior(self, src_sents, tgt_sents, src_masks, tgt_masks, init_scale=1.0):
        """Delegate ``init_prior`` to the unwrapped core."""
        core = self._get_core()
        core.init_prior(src_sents, tgt_sents, src_masks, tgt_masks, init_scale=init_scale)

    def reconstruct(self, src_sents: torch.Tensor, tgt_sents: torch.Tensor,
                    src_masks: torch.Tensor, tgt_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Delegate ``reconstruct`` to the unwrapped core (returns a 5-tuple of tensors)."""
        return self._get_core().reconstruct(src_sents, tgt_sents, src_masks, tgt_masks)

    def translate_argmax(self, src_sents: torch.Tensor, src_masks: torch.Tensor,
                         n_tr: int = 1, tau: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            n_tr: int (default 1)
                number of translations per sentence per length candidate
            tau: float (default 0.0)
                temperature

        Returns: Tensor1, Tensor2
            Tensor1: tensor for translations [batch, tgt_length]
            Tensor2: lengths [batch]

        """
        return self._get_core().translate_argmax(src_sents, src_masks, n_tr=n_tr, tau=tau)

    def translate_iw(self, src_sents: torch.Tensor, src_masks: torch.Tensor,
                     n_len: int = 1, n_tr: int = 1,
                     tau: float = 0.0, k: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            n_len: int (default 1)
                number of length candidates
            n_tr: int (default 1)
                number of translations per sentence per length candidate
            tau: float (default 0.0)
                temperature
            k: int (default 1)
                number of samples for importance weighted sampling

        Returns: Tensor1, Tensor2
            Tensor1: tensor for translations [batch, tgt_length]
            Tensor2: lengths [batch]

        """
        return self._get_core().translate_iw(src_sents, src_masks, n_len=n_len, n_tr=n_tr,
                                             tau=tau, k=k)

    def translate_sample(self, src_sents: torch.Tensor, src_masks: torch.Tensor,
                         n_len: int = 1, n_tr: int = 1, tau: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            n_len: int (default 1)
                number of length candidates
            n_tr: int (default 1)
                number of translations per sentence per length candidate
            tau: float (default 0.0)
                temperature

        Returns: Tensor1, Tensor2
            Tensor1: tensor for translations [batch * n_len * n_tr, tgt_length]
            Tensor2: lengths [batch * n_len * n_tr]

        """
        return self._get_core().translate_sample(src_sents, src_masks, n_len=n_len, n_tr=n_tr, tau=tau)

    def reconstruct_error(self, src_sents: torch.Tensor, tgt_sents: torch.Tensor,
                          src_masks: torch.Tensor, tgt_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            tgt_sents: Tensor [batch, tgt_length]
                tensor for target sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            tgt_masks: Tensor [batch, tgt_length] or None
                tensor for target masks

        Returns: Tensor1, Tensor2
            Tensor1: reconstruction error [batch]
            Tensor2: length loss [batch]

        """
        # Goes through self.core, i.e. the DDP wrapper when distributed.
        return self.core(src_sents, tgt_sents, src_masks, tgt_masks, only_recon_loss=True)

    def loss(self, src_sents: torch.Tensor, tgt_sents: torch.Tensor,
             src_masks: torch.Tensor, tgt_masks: torch.Tensor,
             nsamples: int = 1, eval=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """

        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            tgt_sents: Tensor [batch, tgt_length]
                tensor for target sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            tgt_masks: Tensor [batch, tgt_length] or None
                tensor for target masks
            nsamples: int
                number of samples
            eval: bool
                if True, call the unwrapped core (bypassing the DDP wrapper)

        Returns: Tensor1, Tensor2, Tensor3
            Tensor1: reconstruction error [batch]
            Tensor2: KL [batch]
            Tensor3: length loss [batch]

        """
        core = self._get_core() if eval else self.core
        return core(src_sents, tgt_sents, src_masks, tgt_masks, nsamples=nsamples)

    def init_distributed(self, rank, local_rank):
        """Init the NCCL process group, select the CUDA device and wrap the core in DDP.

        May only be called once.
        """
        assert not self.distribured_enabled
        self.distribured_enabled = True
        print("Initializing Distributed, rank {}, local rank {}".format(rank, local_rank))
        dist.init_process_group(backend='nccl', rank=rank)
        torch.cuda.set_device(local_rank)
        self.core = DistributedDataParallel(self.core)

    def sync_params(self):
        """All-reduce the core's parameter data and set ``needs_refresh`` on the DDP wrapper."""
        assert self.distribured_enabled
        core = self._get_core()
        flat_dist_call([param.data for param in core.parameters()], dist.all_reduce)
        self.core.needs_refresh = True

    def enable_allreduce(self):
        """Enable gradient all-reduce on the DDP wrapper (distributed mode only)."""
        assert self.distribured_enabled
        self.core.enable_allreduce()

    def disable_allreduce(self):
        """Disable gradient all-reduce on the DDP wrapper (distributed mode only)."""
        assert self.distribured_enabled
        self.core.disable_allreduce()

    def save(self, model_path):
        """Save the unwrapped core's state dict to ``<model_path>/model.pt``."""
        model = {'core': self._get_core().state_dict()}
        model_name = os.path.join(model_path, 'model.pt')
        torch.save(model, model_name)

    def save_core(self, path):
        """Save prior/encoder/decoder/posterior state dicts separately to ``path``."""
        core = self._get_core()
        model = {'prior': core.prior.state_dict(),
                 'encoder': core.encoder.state_dict(),
                 'decoder': core.decoder.state_dict(),
                 'posterior': core.posterior.state_dict()}
        torch.save(model, path)

    def load_core(self, path, device, load_prior=True):
        """Load component state dicts written by ``save_core``; the prior is optional."""
        model = torch.load(path, map_location=device)
        core = self._get_core()
        core.posterior.load_state_dict(model['posterior'])
        core.encoder.load_state_dict(model['encoder'])
        core.decoder.load_state_dict(model['decoder'])
        if load_prior:
            core.prior.load_state_dict(model['prior'])

    @classmethod
    def load(cls, model_path, device):
        """Build a model from ``config.json`` in ``model_path`` and load ``model.pt`` into it."""
        # Use a context manager so the config file handle is closed.
        with open(os.path.join(model_path, 'config.json'), 'r') as config_file:
            params = json.load(config_file)
        flownmt = cls.from_params(params).to(device)
        model_name = os.path.join(model_path, 'model.pt')
        model = torch.load(model_name, map_location=device)
        flownmt.core.load_state_dict(model['core'])
        return flownmt

    @classmethod
    def from_params(cls, params: Dict) -> "FlowNMT":
        """Build a model from a config dict.

        Note: keys are popped from ``params`` and its nested sub-dicts are
        modified in place.
        """
        src_vocab_size = params.pop('src_vocab_size')
        tgt_vocab_size = params.pop('tgt_vocab_size')
        embed_dim = params.pop('embed_dim')
        latent_dim = params.pop('latent_dim')
        hidden_size = params.pop('hidden_size')
        max_src_length = params.pop('max_src_length')
        max_tgt_length = params.pop('max_tgt_length')
        src_pad_idx = params.pop('src_pad_idx')
        tgt_pad_idx = params.pop('tgt_pad_idx')

        share_embed = params.pop('share_embed')
        tie_weights = params.pop('tie_weights')

        # prior
        prior_params = params.pop('prior')
        prior_params['flow']['features'] = latent_dim
        prior_params['flow']['src_features'] = latent_dim
        prior_params['length_predictor']['features'] = latent_dim
        prior_params['length_predictor']['max_src_length'] = max_src_length
        prior = Prior.by_name(prior_params.pop('type')).from_params(prior_params)
        # encoder
        encoder_params = params.pop('encoder')
        encoder_params['vocab_size'] = src_vocab_size
        encoder_params['embed_dim'] = embed_dim
        encoder_params['padding_idx'] = src_pad_idx
        encoder_params['latent_dim'] = latent_dim
        encoder_params['hidden_size'] = hidden_size
        encoder = Encoder.by_name(encoder_params.pop('type')).from_params(encoder_params)
        # posterior (optionally sharing the encoder's embedding)
        posterior_params = params.pop('posterior')
        posterior_params['vocab_size'] = tgt_vocab_size
        posterior_params['embed_dim'] = embed_dim
        posterior_params['padding_idx'] = tgt_pad_idx
        posterior_params['latent_dim'] = latent_dim
        posterior_params['hidden_size'] = hidden_size
        _shared_embed = encoder.embed if share_embed else None
        posterior_params['_shared_embed'] = _shared_embed
        posterior = Posterior.by_name(posterior_params.pop('type')).from_params(posterior_params)
        # decoder (optionally tying its weight to the posterior's target embedding)
        decoder_params = params.pop('decoder')
        decoder_params['vocab_size'] = tgt_vocab_size
        decoder_params['latent_dim'] = latent_dim
        decoder_params['hidden_size'] = hidden_size
        _shared_weight = posterior.tgt_embed.weight if tie_weights else None
        decoder_params['_shared_weight'] = _shared_weight
        decoder = Decoder.by_name(decoder_params.pop('type')).from_params(decoder_params)

        return cls(FlowNMTCore(encoder, prior, posterior, decoder))
# Example #3
# 0
# File: wolf.py  Project: TRUMANCFY/wolf
class WolfModel(nn.Module):
    """
    Variational Auto-Encoding Generative Flow

    Thin wrapper around a ``WolfCore``. After ``init_distributed`` the core is
    passed through ``convert_syncbn_model`` and wrapped in
    ``DistributedDataParallel``; ``_get_core`` returns the unwrapped module.
    """
    def __init__(self, core: WolfCore):
        super(WolfModel, self).__init__()
        self.core = core
        # NOTE: attribute name is misspelled ("distribured"); kept as-is
        # because code outside this class may read it.
        self.distribured_enabled = False

    @staticmethod
    def _checkpoint_name(model_path, version=None):
        """Return ``<model_path>/model{version}.pt``, or ``<model_path>/model.pt`` if no version."""
        if version is not None:
            return os.path.join(model_path, 'model{}.pt'.format(version))
        return os.path.join(model_path, 'model.pt')

    def _get_core(self):
        """Return the core, unwrapping the DDP wrapper when distributed."""
        return self.core.module if self.distribured_enabled else self.core

    def sync(self):
        """Delegate ``sync`` to the unwrapped core."""
        core = self._get_core()
        core.sync()

    def init(self, x, y=None, init_scale=1.0):
        """Delegate ``init`` to the unwrapped core."""
        core = self._get_core()
        core.init(x, y=y, init_scale=init_scale)

    def init_distributed(self, rank, local_rank):
        """Init the NCCL process group, select the CUDA device and wrap the core.

        The core is converted with ``convert_syncbn_model`` before being
        wrapped in DDP. May only be called once.
        """
        assert not self.distribured_enabled
        self.distribured_enabled = True
        print("Initializing Distributed, rank {}, local rank {}".format(
            rank, local_rank))
        dist.init_process_group(backend='nccl', rank=rank)
        torch.cuda.set_device(local_rank)
        self.core = DistributedDataParallel(convert_syncbn_model(self.core))

    def enable_allreduce(self):
        """Enable gradient all-reduce on the DDP wrapper (distributed mode only)."""
        assert self.distribured_enabled
        self.core.enable_allreduce()

    def disable_allreduce(self):
        """Disable gradient all-reduce on the DDP wrapper (distributed mode only)."""
        assert self.distribured_enabled
        self.core.disable_allreduce()

    def encode_global(self, data, y=None, n_bits=8, nsamples=1, random=False):
        """

        Args:
            data: Tensor [batch, channels, height, width]
                input data
            y: Tensor or None
                class labels for x or None.
            n_bits: int (default 8)
                number of bits for image data.
            nsamples: int (default 1)
                number of samples for each image.
            random: bool (default False)
                incorporating randomness.

        Returns: Tensor
            tensor for global encoding [batch, nsamples, dim] or None

        """
        return self._get_core().encode_global(data,
                                              y=y,
                                              n_bits=n_bits,
                                              nsamples=nsamples,
                                              random=random)

    def encode(self,
               data,
               y=None,
               n_bits=8,
               nsamples=1,
               random=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            data: Tensor [batch, channels, height, width]
                input data
            y: Tensor or None
                class labels for x or None.
            n_bits: int (default 8)
                number of bits for image data.
            nsamples: int (default 1)
                number of samples for each image.
            random: bool (default False)
                incorporating randomness.

        Returns: Tensor1, Tensor2
            Tensor1: epsilon [batch, channels, height, width]
            Tensor2: z [batch, dim] or None

        """
        return self._get_core().encode(data,
                                       y=y,
                                       n_bits=n_bits,
                                       nsamples=nsamples,
                                       random=random)

    def decode(self, epsilon, z=None, n_bits=8) -> torch.Tensor:
        """

        Args:
            epsilon: Tensor [batch, channels, height, width]
                epsilon for generation
            z: Tensor or None [batch, dim]
                conditional input
            n_bits: int (default 8)
                number of bits for image data.

        Returns: generated tensor [nums, channels, height, width]

        """
        return self._get_core().decode(epsilon, z=z, n_bits=n_bits)

    def decode_with_attn(self, epsilon, z=None, n_bits=8):
        """Delegate ``decode_with_attn`` to the unwrapped core (same arguments as ``decode``)."""
        return self._get_core().decode_with_attn(epsilon, z=z, n_bits=n_bits)

    def synthesize(self,
                   nums,
                   image_size,
                   tau=1.0,
                   n_bits=8,
                   device=torch.device('cpu')) -> torch.Tensor:
        """

        Args:
            nums: int
                number of synthesis
            image_size: size of tuple
                the size of the synthesized images with shape [channels, height, width]
            tau: float (default 1.0)
                temperature
            n_bits: int (default 8)
                number of bits for image data.
            device: torch.device
                device to store the synthesis

        Returns: generated tensor [nums, channels, height, width]

        """
        return self._get_core().synthesize(nums,
                                           image_size,
                                           tau=tau,
                                           n_bits=n_bits,
                                           device=device)

    def loss(self, data, y=None, n_bits=8, nsamples=1) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """

        Args:
            data: Tensor [batch, channels, height, width]
                input data.
            y: Tensor or None
                class labels for x or None.
            n_bits: int (default 8)
                number of bits for image data.
            nsamples: int (default 1)
                number of samples for compute the loss.

        Returns: Tensor1, Tensor2, Tensor3
            Tensor1: generation loss [batch]
            Tensor2: KL [batch]
            Tensor3: dequantization loss [batch]

        """
        # Training goes through self.core (the DDP wrapper when distributed);
        # evaluation calls the unwrapped core.
        core = self._get_core() if not self.training else self.core
        return core(data, y=y, n_bits=n_bits, nsamples=nsamples)

    def loss_attn(self, data, y=None, n_bits=8, nsamples=1) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """

        Args:
            data: Tensor [batch, channels, height, width]
                input data.
            y: Tensor or None
                class labels for x or None.
            n_bits: int (default 8)
                number of bits for image data.
            nsamples: int (default 1)
                number of samples for compute the loss.

        Returns: Tensor1, Tensor2, Tensor3, Tensor4
            Tensor1: generation loss [batch]
            Tensor2: KL [batch]
            Tensor3: dequantization loss [batch]
            Tensor4: extra output of core.forward_attn (presumably
                attention-related -- TODO confirm)

        """
        core = self._get_core() if not self.training else self.core
        return core.forward_attn(data, y=y, n_bits=n_bits, nsamples=nsamples)

    def to_device(self, device):
        """Move the discriminator (via its ``to_device``) and the module to ``device``.

        Must be called before ``init_distributed``.
        """
        assert not self.distribured_enabled
        self.core.discriminator.to_device(device)
        return self.to(device)

    def save(self, model_path, version=None):
        """Save the unwrapped core's state dict to ``model.pt`` / ``model{version}.pt``."""
        model = {'core': self._get_core().state_dict()}
        torch.save(model, self._checkpoint_name(model_path, version))

    @classmethod
    def load(cls, model_path, device, version=None):
        """Build a model from ``config.json`` in ``model_path`` and load its checkpoint."""
        # Use a context manager so the config file handle is closed.
        with open(os.path.join(model_path, 'config.json'), 'r') as config_file:
            params = json.load(config_file)
        flowae = cls.from_params(params).to_device(device)
        model = torch.load(cls._checkpoint_name(model_path, version), map_location=device)
        flowae.core.load_state_dict(model['core'])
        return flowae

    @classmethod
    def from_params(cls, params: Dict) -> "WolfModel":
        """Build a model from a config dict (keys are popped from ``params``)."""
        # discriminator
        disc_params = params.pop('discriminator')
        discriminator = Discriminator.by_name(
            disc_params.pop('type')).from_params(disc_params)
        # dequantizer
        dequant_params = params.pop('dequantizer')
        dequantizer = DeQuantizer.by_name(
            dequant_params.pop('type')).from_params(dequant_params)
        # generator
        generator_params = params.pop('generator')
        generator = Generator.from_params(generator_params)

        return cls(WolfCore(generator, discriminator, dequantizer))