class DDPTrainer(object):
    """Main class for data parallel training.

    This class supports data parallel training, where multiple workers each
    have a full model replica and gradients are accumulated synchronously via
    torch.distributed.all_reduce.
    """

    def __init__(self, args, model):
        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args
        self.model = model.cuda()
        self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

        if args.amp:
            model, optimizer = amp.initialize(
                self.model,
                self.optimizer._optimizer,
                opt_level=self.args.amp_level if self.args.amp_level else 'O2',
                max_loss_scale=2**15,
                cast_model_outputs=torch.float16
            )

        if self.args.distributed_world_size > 1:
            self.model = DDP(model)

        self._buffered_stats = defaultdict(lambda: [])
        self._flat_grads = None
        self._num_updates = 0
        self._num_val_iterations = 0
        self._optim_history = None
        self.throughput_meter = TimeMeter()

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if self.args.amp:
            extra_state['amp_state_dict'] = amp.state_dict()
            extra_state['amp_master_params'] = list(amp.master_params(self.optimizer.optimizer))
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            utils.save_state(
                filename, self.args, self.get_model(), self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, load_optim=True):
        """Load all training state from a checkpoint file."""
        extra_state, optim_history, last_optim_state = \
            utils.load_model_state(filename, self.get_model())

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            # self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
            self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

            if load_optim:
                self._optim_history = optim_history
                # only reload optimizer and lr_scheduler if they match
                last_optim = self._optim_history[-1]
                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                    self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                        self.optimizer.load_state_dict(last_optim_state)

                self._num_updates = last_optim['num_updates']

        if self.args.amp and extra_state is not None and 'amp_state_dict' in extra_state:
            self.optimizer.optimizer._lazy_init_maybe_master_weights()
            self.optimizer.optimizer._amp_stash.lazy_init_called = True
            self.optimizer.optimizer.load_state_dict(last_optim_state)
            for param, saved_param in zip(amp.master_params(self.optimizer.optimizer),
                                          extra_state['amp_master_params']):
                param.data.copy_(saved_param.data)

            amp.load_state_dict(extra_state['amp_state_dict'])

        return extra_state

    def train_step(self, sample, update_params=True, last_step=False):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        if isinstance(self.model, DDP):
            if last_step:
                self.model.disable_allreduce()
            else:
                self.model.enable_allreduce()

        # forward and backward pass
        sample, sample_size = self._prepare_sample(sample)
        loss, oom_fwd = self._forward(sample)

        # If this is the last batch, the forward pass is skipped on some workers.
        # A batch with sample_size 0 is not accounted for in the weighted loss.
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'loss': utils.item(loss.data) if loss is not None else 0,
            'sample_size': sample_size
        }
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)

        # update parameters
        if update_params and not last_step:
            # gather logging outputs from all replicas
            sample_sizes = self._buffered_stats['sample_sizes']
            logging_outputs = self._buffered_stats['logging_outputs']
            ooms_fwd = self._buffered_stats['ooms_fwd']
            ooms_bwd = self._buffered_stats['ooms_bwd']
            if self.args.distributed_world_size > 1:
                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                    lambda l: list(chain.from_iterable(l)),
                    zip(*distributed_utils.all_gather_list(
                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
                    ))
                )
            ooms_fwd = sum(ooms_fwd)
            ooms_bwd = sum(ooms_bwd)
            ooms = ooms_fwd + ooms_bwd  # this is always <= distributed_world_size

            if ooms == self.args.distributed_world_size:
                print('| WARNING: OOM in all workers, skipping batch')
                self.zero_grad()
                return

            # aggregate stats and logging outputs
            grad_denom = sum(sample_sizes)
            for p in self.model.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad /= grad_denom

            self._opt()

            # Handle logging
            sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
            self.throughput_meter.update(ntokens)
            info_log_data = {
                'tokens/s': self.throughput_meter.avg,
                'tokens': ntokens,
                'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2)
            }
            debug_log_data = {
                'batch_size': sum(log.get('nsentences', 0) for log in logging_outputs),
                'lr': self.get_lr(),
                'grad_denom': grad_denom,
                'updates': 1
            }
            DLLogger.log(step=self._num_updates, data=info_log_data, verbosity=0)
            DLLogger.log(step=self._num_updates, data=debug_log_data, verbosity=1)

            self.clear_buffered_stats()

    def _forward(self, sample):
        loss = None
        oom = 0
        try:
            if sample is not None:
                # calculate loss and sample size
                logits, _ = self.model(**sample['net_input'])
                target = sample['target']
                if not self.args.adaptive_softmax_cutoff:
                    probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                else:
                    # TODO: training crashes after a couple hundred iterations because of an
                    # unknown error in PyTorch's autograd
                    probs, target = self.get_model().decoder.adaptive_softmax(logits, target.view(-1))
                loss = self.criterion(probs, target)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
                    self.args.distributed_rank), force=True)
                oom = 1
                loss = None
            else:
                raise e
        return loss, oom

    def _backward(self, loss):
        oom = 0
        if loss is not None:
            try:
                if self.args.amp:
                    with amp.scale_loss(loss, self.optimizer._optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
                        self.args.distributed_rank), force=True)
                    oom = 1
                    self.zero_grad()
                else:
                    raise e
        return oom

    def _opt(self):
        # take an optimization step
        self.optimizer.step()
        self.zero_grad()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

    def valid_step(self, sample):
        """Do forward pass in evaluation mode."""
        self.model.eval()
        self._num_val_iterations += 1

        # forward pass
        sample, sample_size = self._prepare_sample(sample)
        with torch.no_grad():
            loss, oom_fwd = self._forward(sample)
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'sample_size': sample_size
        }
        loss = loss.item() if loss is not None else 0
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            losses, sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
                (loss, sample_size, logging_output)
            ))
        else:
            losses = [loss]
            sample_sizes = [sample_size]
            logging_outputs = [logging_output]

        # TODO: check when ntokens != sample_size
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        weight = sum(log.get('sample_size', 0) for log in logging_outputs)
        scaled_loss = sum(losses) / weight / math.log(2)

        return scaled_loss

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming the caching allocator."""
        self.train_step(dummy_batch, update_params=False)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        self._buffered_stats.clear()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_throughput_meter(self):
        """Get the throughput meter."""
        return self.throughput_meter

    def get_model(self):
        """Get the model replica."""
        return self.model.module if isinstance(self.model, DDP) else self.model

    def get_num_updates(self):
        """Get the number of parameter updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        if sample is None or len(sample) == 0:
            return None, 0
        return utils.move_to_cuda(sample), sample['ntokens']
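

# Usage sketch (illustrative only, not part of the original module): how a
# training script might drive DDPTrainer for one epoch. `args`, `model`,
# `train_batches`, and `valid_batches` are assumed to be built by the
# surrounding script, and `args.update_freq` is a hypothetical gradient
# accumulation setting that this class does not define itself.
def _example_training_loop(args, model, train_batches, valid_batches, epoch):
    trainer = DDPTrainer(args, model)

    # warm up the CUDA caching allocator before timing real batches
    trainer.dummy_train_step(train_batches[0])

    for i, sample in enumerate(train_batches):
        last_step = (i == len(train_batches) - 1)
        # gradients are buffered for `args.update_freq` batches before each optimizer step
        trainer.train_step(sample,
                           update_params=((i + 1) % args.update_freq == 0),
                           last_step=last_step)

    # validate, adjust the learning rate, and checkpoint from the master worker
    val_loss = sum(trainer.valid_step(s) for s in valid_batches) / len(valid_batches)
    trainer.lr_step(epoch, val_loss)
    trainer.save_checkpoint('checkpoint_last.pt', {'epoch': epoch, 'val_loss': val_loss})
    return val_loss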
class FlowNMT(nn.Module):
    """NMT model with Generative Flow."""

    def __init__(self, core: FlowNMTCore):
        super(FlowNMT, self).__init__()
        self.core = core
        self.length_unit = self.core.prior.length_unit
        self.distribured_enabled = False

    def _get_core(self):
        return self.core.module if self.distribured_enabled else self.core

    def sync(self):
        core = self._get_core()
        core.prior.sync()

    def init(self, src_sents, tgt_sents, src_masks, tgt_masks, init_scale=1.0):
        core = self._get_core()
        core.init(src_sents, tgt_sents, src_masks, tgt_masks, init_scale=init_scale)

    def init_posterior(self, src_sents, tgt_sents, src_masks, tgt_masks, init_scale=1.0):
        core = self._get_core()
        core.init_posterior(src_sents, tgt_sents, src_masks, tgt_masks, init_scale=init_scale)

    def init_prior(self, src_sents, tgt_sents, src_masks, tgt_masks, init_scale=1.0):
        core = self._get_core()
        core.init_prior(src_sents, tgt_sents, src_masks, tgt_masks, init_scale=init_scale)

    def reconstruct(self, src_sents: torch.Tensor, tgt_sents: torch.Tensor,
                    src_masks: torch.Tensor, tgt_masks: torch.Tensor) \
            -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return self._get_core().reconstruct(src_sents, tgt_sents, src_masks, tgt_masks)

    def translate_argmax(self, src_sents: torch.Tensor, src_masks: torch.Tensor,
                         n_tr: int = 1, tau: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            n_tr: int (default 1)
                number of translations per sentence per length candidate
            tau: float (default 0.0)
                temperature

        Returns: Tensor1, Tensor2
            Tensor1: tensor for translations [batch, tgt_length]
            Tensor2: lengths [batch]
        """
        return self._get_core().translate_argmax(src_sents, src_masks, n_tr=n_tr, tau=tau)

    def translate_iw(self, src_sents: torch.Tensor, src_masks: torch.Tensor,
                     n_len: int = 1, n_tr: int = 1, tau: float = 0.0,
                     k: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            n_len: int (default 1)
                number of length candidates
            n_tr: int (default 1)
                number of translations per sentence per length candidate
            tau: float (default 0.0)
                temperature
            k: int (default 1)
                number of samples for importance weighted sampling

        Returns: Tensor1, Tensor2
            Tensor1: tensor for translations [batch, tgt_length]
            Tensor2: lengths [batch]
        """
        return self._get_core().translate_iw(src_sents, src_masks,
                                             n_len=n_len, n_tr=n_tr, tau=tau, k=k)

    def translate_sample(self, src_sents: torch.Tensor, src_masks: torch.Tensor,
                         n_len: int = 1, n_tr: int = 1,
                         tau: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            n_len: int (default 1)
                number of length candidates
            n_tr: int (default 1)
                number of translations per sentence per length candidate
            tau: float (default 0.0)
                temperature

        Returns: Tensor1, Tensor2
            Tensor1: tensor for translations [batch * n_len * n_tr, tgt_length]
            Tensor2: lengths [batch * n_len * n_tr]
        """
        return self._get_core().translate_sample(src_sents, src_masks,
                                                 n_len=n_len, n_tr=n_tr, tau=tau)

    def reconstruct_error(self, src_sents: torch.Tensor, tgt_sents: torch.Tensor,
                          src_masks: torch.Tensor,
                          tgt_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            tgt_sents: Tensor [batch, tgt_length]
                tensor for target sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            tgt_masks: Tensor [batch, tgt_length] or None
                tensor for target masks

        Returns: Tensor1, Tensor2
            Tensor1: reconstruction error [batch]
            Tensor2: length loss [batch]
        """
        return self.core(src_sents, tgt_sents, src_masks, tgt_masks, only_recon_loss=True)

    def loss(self, src_sents: torch.Tensor, tgt_sents: torch.Tensor,
             src_masks: torch.Tensor, tgt_masks: torch.Tensor,
             nsamples: int = 1, eval=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            src_sents: Tensor [batch, src_length]
                tensor for source sentences
            tgt_sents: Tensor [batch, tgt_length]
                tensor for target sentences
            src_masks: Tensor [batch, src_length] or None
                tensor for source masks
            tgt_masks: Tensor [batch, tgt_length] or None
                tensor for target masks
            nsamples: int
                number of samples
            eval: bool
                if eval, turn off distributed mode

        Returns: Tensor1, Tensor2, Tensor3
            Tensor1: reconstruction error [batch]
            Tensor2: KL [batch]
            Tensor3: length loss [batch]
        """
        core = self._get_core() if eval else self.core
        return core(src_sents, tgt_sents, src_masks, tgt_masks, nsamples=nsamples)

    def init_distributed(self, rank, local_rank):
        assert not self.distribured_enabled
        self.distribured_enabled = True
        print("Initializing Distributed, rank {}, local rank {}".format(rank, local_rank))
        dist.init_process_group(backend='nccl', rank=rank)
        torch.cuda.set_device(local_rank)
        self.core = DistributedDataParallel(self.core)

    def sync_params(self):
        assert self.distribured_enabled
        core = self._get_core()
        flat_dist_call([param.data for param in core.parameters()], dist.all_reduce)
        self.core.needs_refresh = True

    def enable_allreduce(self):
        assert self.distribured_enabled
        self.core.enable_allreduce()

    def disable_allreduce(self):
        assert self.distribured_enabled
        self.core.disable_allreduce()

    def save(self, model_path):
        model = {'core': self._get_core().state_dict()}
        model_name = os.path.join(model_path, 'model.pt')
        torch.save(model, model_name)

    def save_core(self, path):
        core = self._get_core()
        model = {'prior': core.prior.state_dict(),
                 'encoder': core.encoder.state_dict(),
                 'decoder': core.decoder.state_dict(),
                 'posterior': core.posterior.state_dict()}
        torch.save(model, path)

    def load_core(self, path, device, load_prior=True):
        model = torch.load(path, map_location=device)
        core = self._get_core()
        core.posterior.load_state_dict(model['posterior'])
        core.encoder.load_state_dict(model['encoder'])
        core.decoder.load_state_dict(model['decoder'])
        if load_prior:
            core.prior.load_state_dict(model['prior'])

    @classmethod
    def load(cls, model_path, device):
        params = json.load(open(os.path.join(model_path, 'config.json'), 'r'))
        flownmt = FlowNMT.from_params(params).to(device)
        model_name = os.path.join(model_path, 'model.pt')
        model = torch.load(model_name, map_location=device)
        flownmt.core.load_state_dict(model['core'])
        return flownmt

    @classmethod
    def from_params(cls, params: Dict) -> "FlowNMT":
        src_vocab_size = params.pop('src_vocab_size')
        tgt_vocab_size = params.pop('tgt_vocab_size')
        embed_dim = params.pop('embed_dim')
        latent_dim = params.pop('latent_dim')
        hidden_size = params.pop('hidden_size')
        max_src_length = params.pop('max_src_length')
        max_tgt_length = params.pop('max_tgt_length')
        src_pad_idx = params.pop('src_pad_idx')
        tgt_pad_idx = params.pop('tgt_pad_idx')
        share_embed = params.pop('share_embed')
        tie_weights = params.pop('tie_weights')

        # prior
        prior_params = params.pop('prior')
        prior_params['flow']['features'] = latent_dim
        prior_params['flow']['src_features'] = latent_dim
        prior_params['length_predictor']['features'] = latent_dim
        prior_params['length_predictor']['max_src_length'] = max_src_length
        prior = Prior.by_name(prior_params.pop('type')).from_params(prior_params)
        # encoder
        encoder_params = params.pop('encoder')
        encoder_params['vocab_size'] = src_vocab_size
        encoder_params['embed_dim'] = embed_dim
        encoder_params['padding_idx'] = src_pad_idx
        encoder_params['latent_dim'] = latent_dim
        encoder_params['hidden_size'] = hidden_size
        encoder = Encoder.by_name(encoder_params.pop('type')).from_params(encoder_params)
        # posterior
        posterior_params = params.pop('posterior')
        posterior_params['vocab_size'] = tgt_vocab_size
        posterior_params['embed_dim'] = embed_dim
        posterior_params['padding_idx'] = tgt_pad_idx
        posterior_params['latent_dim'] = latent_dim
        posterior_params['hidden_size'] = hidden_size
        _shared_embed = encoder.embed if share_embed else None
        posterior_params['_shared_embed'] = _shared_embed
        posterior = Posterior.by_name(posterior_params.pop('type')).from_params(posterior_params)
        # decoder
        decoder_params = params.pop('decoder')
        decoder_params['vocab_size'] = tgt_vocab_size
        decoder_params['latent_dim'] = latent_dim
        decoder_params['hidden_size'] = hidden_size
        _shared_weight = posterior.tgt_embed.weight if tie_weights else None
        decoder_params['_shared_weight'] = _shared_weight
        decoder = Decoder.by_name(decoder_params.pop('type')).from_params(decoder_params)

        return FlowNMT(FlowNMTCore(encoder, prior, posterior, decoder))
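

# Usage sketch (illustrative, not part of the original module): loading a trained
# FlowNMT model from a directory containing `config.json` and `model.pt`, then
# decoding one batch. `src_sents` (LongTensor of token ids) and `src_masks`
# (0/1 mask or None) are assumed to come from the project's data pipeline.
def _example_translate(model_path, src_sents, src_masks):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    flownmt = FlowNMT.load(model_path, device)
    flownmt.eval()
    with torch.no_grad():
        # argmax decoding with 4 candidate translations per predicted length
        translations, lengths = flownmt.translate_argmax(
            src_sents.to(device), src_masks.to(device), n_tr=4, tau=0.0)
    return translations, lengths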
class WolfModel(nn.Module):
    """Variational Auto-Encoding Generative Flow."""

    def __init__(self, core: WolfCore):
        super(WolfModel, self).__init__()
        self.core = core
        self.distribured_enabled = False

    def _get_core(self):
        return self.core.module if self.distribured_enabled else self.core

    def sync(self):
        core = self._get_core()
        core.sync()

    def init(self, x, y=None, init_scale=1.0):
        core = self._get_core()
        core.init(x, y=y, init_scale=init_scale)

    def init_distributed(self, rank, local_rank):
        assert not self.distribured_enabled
        self.distribured_enabled = True
        print("Initializing Distributed, rank {}, local rank {}".format(rank, local_rank))
        dist.init_process_group(backend='nccl', rank=rank)
        torch.cuda.set_device(local_rank)
        self.core = DistributedDataParallel(convert_syncbn_model(self.core))

    def enable_allreduce(self):
        assert self.distribured_enabled
        self.core.enable_allreduce()

    def disable_allreduce(self):
        assert self.distribured_enabled
        self.core.disable_allreduce()

    def encode_global(self, data, y=None, n_bits=8, nsamples=1, random=False):
        """
        Args:
            data: Tensor [batch, channels, height, width]
                input data
            y: Tensor or None
                class labels for x or None.
            n_bits: int (default 8)
                number of bits for image data.
            nsamples: int (default 1)
                number of samples for each image.
            random: bool (default False)
                incorporating randomness.

        Returns: Tensor
            tensor for global encoding [batch, nsamples, dim] or None
        """
        return self._get_core().encode_global(data, y=y, n_bits=n_bits,
                                              nsamples=nsamples, random=random)

    def encode(self, data, y=None, n_bits=8, nsamples=1,
               random=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            data: Tensor [batch, channels, height, width]
                input data
            y: Tensor or None
                class labels for x or None.
            n_bits: int (default 8)
                number of bits for image data.
            nsamples: int (default 1)
                number of samples for each image.
            random: bool (default False)
                incorporating randomness.

        Returns: Tensor1, Tensor2
            Tensor1: epsilon [batch, channels, height, width]
            Tensor2: z [batch, dim] or None
        """
        return self._get_core().encode(data, y=y, n_bits=n_bits,
                                       nsamples=nsamples, random=random)

    def decode(self, epsilon, z=None, n_bits=8) -> torch.Tensor:
        """
        Args:
            epsilon: Tensor [batch, channels, height, width]
                epsilon for generation
            z: Tensor or None
                [batch, dim] conditional input
            n_bits: int (default 8)
                number of bits for image data.

        Returns:
            generated tensor [nums, channels, height, width]
        """
        return self._get_core().decode(epsilon, z=z, n_bits=n_bits)

    def decode_with_attn(self, epsilon, z=None, n_bits=8):
        return self._get_core().decode_with_attn(epsilon, z=z, n_bits=n_bits)

    def synthesize(self, nums, image_size, tau=1.0, n_bits=8,
                   device=torch.device('cpu')) -> torch.Tensor:
        """
        Args:
            nums: int
                number of images to synthesize
            image_size: tuple
                the size of the synthesized images with shape [channels, height, width]
            tau: float (default 1.0)
                temperature
            n_bits: int (default 8)
                number of bits for image data.
            device: torch.device
                device to store the synthesized images

        Returns:
            generated tensor [nums, channels, height, width]
        """
        return self._get_core().synthesize(nums, image_size, tau=tau, n_bits=n_bits, device=device)

    def loss(self, data, y=None, n_bits=8, nsamples=1) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            data: Tensor [batch, channels, height, width]
                input data.
            y: Tensor or None
                class labels for x or None.
            n_bits: int (default 8)
                number of bits for image data.
            nsamples: int (default 1)
                number of samples to compute the loss.

        Returns: Tensor1, Tensor2, Tensor3
            Tensor1: generation loss [batch]
            Tensor2: KL [batch]
            Tensor3: dequantization loss [batch]
        """
        core = self._get_core() if not self.training else self.core
        return core(data, y=y, n_bits=n_bits, nsamples=nsamples)

    def loss_attn(self, data, y=None, n_bits=8, nsamples=1) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            data: Tensor [batch, channels, height, width]
                input data.
            y: Tensor or None
                class labels for x or None.
            n_bits: int (default 8)
                number of bits for image data.
            nsamples: int (default 1)
                number of samples to compute the loss.

        Returns: Tensor1, Tensor2, Tensor3
            Tensor1: generation loss [batch]
            Tensor2: KL [batch]
            Tensor3: dequantization loss [batch]
        """
        core = self._get_core() if not self.training else self.core
        return core.forward_attn(data, y=y, n_bits=n_bits, nsamples=nsamples)

    def to_device(self, device):
        assert not self.distribured_enabled
        self.core.discriminator.to_device(device)
        return self.to(device)

    def save(self, model_path, version=None):
        model = {'core': self._get_core().state_dict()}
        if version is not None:
            model_name = os.path.join(model_path, 'model{}.pt'.format(version))
        else:
            model_name = os.path.join(model_path, 'model.pt')
        torch.save(model, model_name)

    @classmethod
    def load(cls, model_path, device, version=None):
        params = json.load(open(os.path.join(model_path, 'config.json'), 'r'))
        flowae = WolfModel.from_params(params).to_device(device)
        if version is not None:
            model_name = os.path.join(model_path, 'model{}.pt'.format(version))
        else:
            model_name = os.path.join(model_path, 'model.pt')
        model = torch.load(model_name, map_location=device)
        flowae.core.load_state_dict(model['core'])
        return flowae

    @classmethod
    def from_params(cls, params: Dict) -> "WolfModel":
        # discriminator
        disc_params = params.pop('discriminator')
        discriminator = Discriminator.by_name(disc_params.pop('type')).from_params(disc_params)
        # dequantizer
        dequant_params = params.pop('dequantizer')
        dequantizer = DeQuantizer.by_name(dequant_params.pop('type')).from_params(dequant_params)
        # generator
        generator_params = params.pop('generator')
        generator = Generator.from_params(generator_params)

        return WolfModel(WolfCore(generator, discriminator, dequantizer))
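

# Usage sketch (illustrative, not part of the original module): loading a trained
# WolfModel from a directory containing `config.json` and `model.pt` and sampling
# a small batch of images. The 3x32x32 image size is only an example and should
# match the data the model was trained on.
def _example_synthesize(model_path, nums=16):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    wolf = WolfModel.load(model_path, device, version=None)
    wolf.eval()
    with torch.no_grad():
        # a lower temperature trades sample diversity for sample quality
        images = wolf.synthesize(nums, (3, 32, 32), tau=0.7, n_bits=8, device=device)
    return images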