def __iter__(self):
    gnmt_print(key=mlperf_log.INPUT_ORDER, sync=False)
    rng = self.init_rng()
    global_bs = self.global_batch_size

    indices = []
    for bid in range(self.num_buckets):
        # random shuffle within current bucket
        perm = torch.randperm(len(self.buckets[bid]), generator=rng)
        bucket_indices = self.buckets[bid][perm]

        # make bucket_indices evenly divisible by global batch size
        length = len(bucket_indices) // global_bs * global_bs
        bucket_indices = bucket_indices[:length]

        assert len(bucket_indices) % self.global_batch_size == 0

        # add samples from current bucket to indices for current epoch
        indices.append(bucket_indices)

    indices = torch.cat(indices)
    assert len(indices) % self.global_batch_size == 0

    # perform global reshuffle of all global batches
    indices = self.reshuffle_batches(indices, rng)
    # distribute batches to individual workers
    indices = self.distribute_batches(indices)
    return iter(indices)
def __iter__(self):
    gnmt_print(key=mlperf_log.INPUT_ORDER, sync=False)
    rng = self.init_rng()

    # generate permutation
    indices = torch.randperm(self.data_len, generator=rng)

    # make indices evenly divisible by (batch_size * world_size)
    indices = indices[:self.num_samples]

    # splits the dataset into chunks of 'self.shard_size' global batches
    # each, sorts by (src + tgt) sequence length within each chunk,
    # reshuffles all global batches
    shard_size = self.global_batch_size * self.shard_size
    nshards = (self.num_samples + shard_size - 1) // shard_size

    lengths = self.dataset.lengths[indices]

    shards = [indices[i * shard_size:(i + 1) * shard_size]
              for i in range(nshards)]
    len_shards = [lengths[i * shard_size:(i + 1) * shard_size]
                  for i in range(nshards)]

    # sort by (src + tgt) sequence length within each shard
    indices = []
    for len_shard in len_shards:
        _, ind = len_shard.sort()
        indices.append(ind)

    output = tuple(shard[idx] for shard, idx in zip(shards, indices))

    # build batches
    indices = torch.cat(output)

    # perform global reshuffle of all global batches
    indices = self.reshuffle_batches(indices, rng)
    # distribute batches to individual workers
    indices = self.distribute_batches(indices)
    return iter(indices)
def __iter__(self):
    gnmt_print(key=mlperf_log.INPUT_ORDER)
    # deterministically shuffle based on epoch
    g = torch.Generator()
    seed = self.seeds[self.epoch]
    logging.info(f'Sampler for epoch {self.epoch} uses seed {seed}')
    g.manual_seed(seed)

    # generate permutation
    indices = torch.randperm(self.data_len, generator=g)

    # make indices evenly divisible by (batch_size * world_size)
    indices = indices[:self.num_samples]

    # splits the dataset into chunks of 'batches_in_shard' global batches
    # each, sorts by (src + tgt) sequence length within each chunk,
    # reshuffles all global batches
    if self.bucketing:
        batches_in_shard = 80
        shard_size = self.global_batch_size * batches_in_shard
        gnmt_print(key=mlperf_log.INPUT_SHARD, value=shard_size)
        nshards = (self.num_samples + shard_size - 1) // shard_size

        lengths = self.dataset.lengths[indices]

        shards = [indices[i * shard_size:(i + 1) * shard_size]
                  for i in range(nshards)]
        len_shards = [lengths[i * shard_size:(i + 1) * shard_size]
                      for i in range(nshards)]

        indices = []
        for len_shard in len_shards:
            _, ind = len_shard.sort()
            indices.append(ind)

        output = tuple(shard[idx] for shard, idx in zip(shards, indices))
        indices = torch.cat(output)

    # global reshuffle
    indices = indices.view(-1, self.global_batch_size)
    order = torch.randperm(indices.shape[0], generator=g)
    indices = indices[order, :]
    indices = indices.view(-1)

    assert len(indices) == self.num_samples

    # build indices for each individual worker
    # consecutive ranks are getting consecutive batches,
    # default pytorch DistributedSampler assigns strided batches
    # with offset = length / world_size
    indices = indices.view(-1, self.batch_size)
    indices = indices[self.rank::self.world_size].contiguous()
    indices = indices.view(-1)
    indices = indices.tolist()

    assert len(indices) == self.num_samples // self.world_size

    return iter(indices)
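# Illustration (not part of the original source): the worker assignment
# performed by the final view/slice above. With 12 sample indices,
# batch_size=2 and world_size=2 there are 6 per-worker batches; rank 0
# takes rows 0, 2, 4 and rank 1 takes rows 1, 3, 5, so at each global step
# consecutive ranks process consecutive batches (rows 0 and 1, then 2 and
# 3, and so on).
import torch

indices = torch.arange(12)               # stands in for the shuffled indices
rows = indices.view(-1, 2)               # 6 per-worker batches
rank0 = rows[0::2].reshape(-1).tolist()  # [0, 1, 4, 5, 8, 9]
rank1 = rows[1::2].reshape(-1).tolist()  # [2, 3, 6, 7, 10, 11]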
def __iter__(self):
    gnmt_print(key=mlperf_log.INPUT_ORDER, sync=False)
    rng = self.init_rng()

    # generate permutation
    indices = torch.randperm(self.data_len, generator=rng)

    # make indices evenly divisible by (batch_size * world_size)
    indices = indices[:self.num_samples]

    # assign batches to workers
    indices = self.distribute_batches(indices)
    return iter(indices)
def __init__(self, model, beam_size=5, max_seq_len=100, cuda=False,
             len_norm_factor=0.6, len_norm_const=5,
             cov_penalty_factor=0.1):
    """
    Constructor for the SequenceGenerator.

    Beam search decoding supports coverage penalty and length
    normalization. For details, refer to Section 7 of the GNMT paper
    (https://arxiv.org/pdf/1609.08144.pdf).

    :param model: model which implements generate method
    :param beam_size: decoder beam size
    :param max_seq_len: maximum decoder sequence length
    :param cuda: whether to use cuda
    :param len_norm_factor: length normalization factor
    :param len_norm_const: length normalization constant
    :param cov_penalty_factor: coverage penalty factor
    """
    self.model = model
    self.cuda = cuda
    self.beam_size = beam_size
    self.max_seq_len = max_seq_len
    self.len_norm_factor = len_norm_factor
    self.len_norm_const = len_norm_const
    self.cov_penalty_factor = cov_penalty_factor

    self.batch_first = self.model.batch_first

    gnmt_print(key=mlperf_log.EVAL_HP_BEAM_SIZE,
               value=self.beam_size, sync=False)
    gnmt_print(key=mlperf_log.EVAL_HP_MAX_SEQ_LEN,
               value=self.max_seq_len, sync=False)
    gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_CONST,
               value=self.len_norm_const, sync=False)
    gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_FACTOR,
               value=self.len_norm_factor, sync=False)
    gnmt_print(key=mlperf_log.EVAL_HP_COV_PENALTY_FACTOR,
               value=self.cov_penalty_factor, sync=False)
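# Illustration (not part of the original source): the length-normalization
# term from Section 7 of the GNMT paper that len_norm_const and
# len_norm_factor parameterize. A candidate's log-probability is divided
# by lp(Y), and a coverage penalty scaled by cov_penalty_factor is added
# on top. A minimal sketch, not necessarily how the repo's beam search
# computes it.
def length_penalty(length, len_norm_const=5.0, len_norm_factor=0.6):
    # lp(Y) = ((C + |Y|) / (C + 1)) ** alpha
    return ((len_norm_const + length) / (len_norm_const + 1.0)) ** len_norm_factor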
def __init__(self, dataset, batch_size, pad, world_size=None, rank=None):
    """
    Constructor for the StaticDistributedSampler.

    :param dataset: dataset
    :param batch_size: local batch size
    :param pad: if True: pads dataset to a multiple of global_batch_size
        samples
    :param world_size: number of distributed workers
    :param rank: rank of the current process
    """
    if world_size is None:
        world_size = get_world_size()
    if rank is None:
        rank = get_rank()

    self.world_size = world_size

    global_batch_size = batch_size * world_size

    gnmt_print(key=mlperf_log.INPUT_ORDER, sync=False)
    data_len = len(dataset)
    num_samples = (data_len + global_batch_size - 1) \
        // global_batch_size * global_batch_size
    self.num_samples = num_samples

    indices = list(range(data_len))
    if pad:
        # pad dataset to a multiple of global_batch_size samples, uses
        # sample with idx 0 as pad
        indices += [0] * (num_samples - len(indices))
    else:
        # temporary pad to a multiple of global batch size, pads with "-1"
        # which is later removed from the list of indices
        indices += [-1] * (num_samples - len(indices))
    indices = torch.tensor(indices)

    indices = indices.view(-1, batch_size)
    indices = indices[rank::world_size].contiguous()
    indices = indices.view(-1)
    # remove temporary pad
    indices = indices[indices != -1]
    indices = indices.tolist()
    self.indices = indices
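# Illustration (not part of the original source): the effect of the pad
# flag. With data_len=10, batch_size=2 and world_size=2 the sampler pads
# up to num_samples=12. pad=True repeats sample 0, while pad=False inserts
# -1 markers that are stripped after the per-rank slice, so some ranks
# simply receive fewer samples.
import torch

indices = torch.tensor(list(range(10)) + [-1, -1])  # pad=False case
rows = indices.view(-1, 2)
rank1 = rows[1::2].reshape(-1)
rank1 = rank1[rank1 != -1].tolist()                  # [2, 3, 6, 7]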
def build_criterion(vocab_size, padding_idx, smoothing):
    if smoothing == 0.:
        logging.info(f'Building CrossEntropyLoss')
        loss_weight = torch.ones(vocab_size)
        loss_weight[padding_idx] = 0
        criterion = nn.CrossEntropyLoss(weight=loss_weight, size_average=False)
        gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN,
                   value='Cross Entropy')
    else:
        logging.info(f'Building LabelSmoothingLoss (smoothing: {smoothing})')
        criterion = LabelSmoothing(padding_idx, smoothing)
        gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN,
                   value='Cross Entropy with label smoothing')
        gnmt_print(key=mlperf_log.MODEL_HP_LOSS_SMOOTHING,
                   value=smoothing)

    return criterion
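# Illustration (not part of the original source): what a label-smoothing
# criterion along the lines of LabelSmoothing(padding_idx, smoothing)
# typically computes; a minimal sketch, not the repo's implementation.
import torch
import torch.nn.functional as F

def smoothed_nll(logits, target, padding_idx, smoothing=0.1):
    # distribute `smoothing` probability mass uniformly over non-target classes
    n_classes = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    true_dist = torch.full_like(log_probs, smoothing / (n_classes - 1))
    true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
    loss = -(true_dist * log_probs).sum(dim=-1)
    loss[target == padding_idx] = 0.0  # ignore padding positions
    return loss.sum()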
def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2,
             batch_first=False, share_embedding=True):
    """
    Constructor for the GNMT v2 model.

    :param vocab_size: size of vocabulary (number of tokens)
    :param hidden_size: internal hidden size of the model
    :param num_layers: number of layers, applies to both encoder and
        decoder
    :param dropout: probability of dropout (in encoder and decoder)
    :param batch_first: if True the model uses (batch, seq, feature)
        tensors, if False the model uses (seq, batch, feature)
    :param share_embedding: if True embeddings are shared between encoder
        and decoder
    """
    super(GNMT, self).__init__(batch_first=batch_first)

    gnmt_print(key=mlperf_log.MODEL_HP_NUM_LAYERS,
               value=num_layers, sync=False)
    gnmt_print(key=mlperf_log.MODEL_HP_HIDDEN_SIZE,
               value=hidden_size, sync=False)
    gnmt_print(key=mlperf_log.MODEL_HP_DROPOUT,
               value=dropout, sync=False)

    if share_embedding:
        embedder = nn.Embedding(vocab_size, hidden_size,
                                padding_idx=config.PAD)
        nn.init.uniform_(embedder.weight.data, -0.1, 0.1)
    else:
        embedder = None

    self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size,
                                            num_layers, dropout,
                                            batch_first, embedder)

    self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size,
                                            num_layers, dropout,
                                            batch_first, embedder)
def __init__(self, vocab_size, hidden_size=512, num_layers=8, bias=True,
             dropout=0.2, batch_first=False, math='fp32',
             share_embedding=False):
    """
    Constructor for the GNMT v2 model.

    :param vocab_size: size of vocabulary (number of tokens)
    :param hidden_size: internal hidden size of the model
    :param num_layers: number of layers, applies to both encoder and
        decoder
    :param bias: globally enables or disables bias in encoder and decoder
    :param dropout: probability of dropout (in encoder and decoder)
    :param batch_first: if True the model uses (batch, seq, feature)
        tensors, if False the model uses (seq, batch, feature)
    :param math: arithmetic type, 'fp32' or 'fp16'
    :param share_embedding: if True embeddings are shared between encoder
        and decoder
    """
    super(GNMT, self).__init__(batch_first=batch_first)

    gnmt_print(key=mlperf_log.MODEL_HP_NUM_LAYERS,
               value=num_layers)
    gnmt_print(key=mlperf_log.MODEL_HP_HIDDEN_SIZE,
               value=hidden_size)
    gnmt_print(key=mlperf_log.MODEL_HP_DROPOUT,
               value=dropout)

    if share_embedding:
        embedder = nn.Embedding(vocab_size, hidden_size,
                                padding_idx=config.PAD)
    else:
        embedder = None

    self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size,
                                            num_layers, bias, dropout,
                                            batch_first, embedder)

    self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size,
                                            num_layers, bias, dropout,
                                            batch_first, math, embedder)
def __init__(self, optimizer, iterations, warmup_steps=0, remain_steps=1.0,
             decay_interval=None, decay_steps=4, decay_factor=0.5,
             last_epoch=-1):
    """
    Constructor of WarmupMultiStepLR.

    Parameters: warmup_steps, remain_steps and decay_interval accept both
    integers and floats as an input. Integer input is interpreted as an
    absolute iteration index, float input is interpreted as a fraction of
    total training iterations (epochs * steps_per_epoch).

    If decay_interval is None then the decay happens at regularly spaced
    intervals ('decay_steps' decays between iteration indices
    'remain_steps' and 'iterations').

    :param optimizer: instance of optimizer
    :param iterations: total number of training iterations
    :param warmup_steps: number of warmup iterations
    :param remain_steps: start decay at 'remain_steps' iteration
    :param decay_interval: interval between LR decay steps
    :param decay_steps: max number of decay steps
    :param decay_factor: decay factor
    :param last_epoch: the index of last iteration
    """
    # iterations before learning rate reaches base LR
    self.warmup_steps = perhaps_convert_float(warmup_steps, iterations)
    logging.info(f'Scheduler warmup steps: {self.warmup_steps}')

    # iteration at which decay starts
    self.remain_steps = perhaps_convert_float(remain_steps, iterations)
    logging.info(f'Scheduler remain steps: {self.remain_steps}')

    # number of steps between each decay
    if decay_interval is None:
        # decay at regularly spaced intervals
        decay_iterations = iterations - self.remain_steps
        self.decay_interval = decay_iterations // decay_steps
        self.decay_interval = max(self.decay_interval, 1)
    else:
        self.decay_interval = perhaps_convert_float(decay_interval,
                                                    iterations)
    logging.info(f'Scheduler decay interval: {self.decay_interval}')

    # multiplicative decay factor
    self.decay_factor = decay_factor
    logging.info(f'Scheduler decay factor: {self.decay_factor}')

    # max number of decay steps
    self.decay_steps = decay_steps
    logging.info(f'Scheduler max decay steps: {self.decay_steps}')

    if self.warmup_steps > self.remain_steps:
        logging.warning('warmup_steps should not be larger than '
                        'remain_steps, setting warmup_steps=remain_steps')
        self.warmup_steps = self.remain_steps

    gnmt_print(key=mlperf_log.OPT_LR_WARMUP_STEPS,
               value=self.warmup_steps, sync=False)

    super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
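# Illustration (not part of the original source): the int-or-float
# convention described in the docstring above. A minimal sketch of what a
# helper like perhaps_convert_float is assumed to do: floats are treated
# as fractions of the total iteration count, ints as absolute iteration
# indices.
def perhaps_convert_float(param, total):
    if isinstance(param, float):
        param = int(param * total)
    return param

# e.g. perhaps_convert_float(0.666, 15000) -> 9990
#      perhaps_convert_float(200, 15000)   -> 200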
def main(): """ Launches data-parallel multi-gpu training. """ mlperf_log.ROOT_DIR_GNMT = os.path.dirname(os.path.abspath(__file__)) mlperf_log.LOGGER.propagate = False args = parse_args() if args.cuda: torch.cuda.set_device(args.local_rank) device = torch.device('cuda') else: device = torch.device('cpu') # initialize distributed backend distributed = False if 'WORLD_SIZE' in os.environ: distributed = int(os.environ['WORLD_SIZE']) > 1 if distributed: assert args.cuda '''Initialize distributed communication''' torch.distributed.init_process_group(backend='nccl', init_method='env://') assert torch.distributed.is_initialized() gnmt_print(key=mlperf_log.RUN_START) args.rank = get_rank() if not args.cudnn: torch.backends.cudnn.enabled = False # create directory for results save_path = os.path.join(args.results_dir, args.save) args.save_path = save_path os.makedirs(save_path, exist_ok=True) # setup logging log_filename = f'log_gpu_{args.rank}.log' setup_logging(os.path.join(save_path, log_filename)) logging.info(f'Saving results to: {save_path}') logging.info(f'Run arguments: {args}') # setup L2 promotion if args.cuda: l2_promote() gnmt_print(key=mlperf_log.RUN_SET_RANDOM_SEED) # https://github.com/mlperf/policies/issues/120#issuecomment-431111348 if args.seed is None: # random master seed, random.SystemRandom() uses /dev/urandom on Unix master_seed = random.SystemRandom().randint(0, 2**32 - 1) if get_rank() == 0: # master seed is reported only from rank=0 worker, it's to avoid # confusion, seeds from rank=0 are later broadcasted to other # workers logging.info(f'Using random master seed: {master_seed}') else: # master seed was specified from command line master_seed = args.seed logging.info(f'Using master seed from command line: {master_seed}') # initialize seeding RNG seeding_rng = random.Random(master_seed) # generate worker seeds, one seed for every distributed worker worker_seeds = generate_seeds(seeding_rng, get_world_size()) # generate seeds for data shuffling, one seed for every epoch shuffling_seeds = generate_seeds(seeding_rng, args.epochs) # broadcast seeds from rank=0 to other workers worker_seeds = broadcast_seeds(worker_seeds, device) shuffling_seeds = broadcast_seeds(shuffling_seeds, device) # set worker seed worker_seed = worker_seeds[args.rank] logging.info(f'Worker {args.rank} is using worker seed: {worker_seed}') torch.manual_seed(worker_seed) # build tokenizer tokenizer = Tokenizer(os.path.join(args.dataset_dir, config.VOCAB_FNAME)) # build datasets gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_TRAINING) gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN, value=args.max_length_train) train_data = LazyParallelDataset( src_fname=os.path.join(args.dataset_dir, config.SRC_TRAIN_FNAME), tgt_fname=os.path.join(args.dataset_dir, config.TGT_TRAIN_FNAME), tokenizer=tokenizer, min_len=args.min_length_train, max_len=args.max_length_train, sort=False, max_size=args.max_size) gnmt_print(key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES, value=len(train_data)) val_data = ParallelDataset(src_fname=os.path.join(args.dataset_dir, config.SRC_VAL_FNAME), tgt_fname=os.path.join(args.dataset_dir, config.TGT_VAL_FNAME), tokenizer=tokenizer, min_len=args.min_length_val, max_len=args.max_length_val, sort=True) gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_EVAL) test_data = TextDataset(src_fname=os.path.join(args.dataset_dir, config.SRC_TEST_FNAME), tokenizer=tokenizer, min_len=args.min_length_test, max_len=args.max_length_test, sort=False) gnmt_print(key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES, value=len(test_data)) 
    vocab_size = tokenizer.vocab_size
    # size of the vocabulary has been padded to a multiple of 8
    gnmt_print(key=mlperf_log.PREPROC_VOCAB_SIZE, value=vocab_size)

    # build GNMT model
    model_config = dict(vocab_size=vocab_size, math=args.math,
                        **literal_eval(args.model_config))
    model = GNMT(**model_config)
    logging.info(model)

    batch_first = model.batch_first

    # define loss function (criterion) and optimizer
    criterion = build_criterion(vocab_size, config.PAD, args.smoothing)

    opt_config = literal_eval(args.optimization_config)
    scheduler_config = literal_eval(args.scheduler_config)
    logging.info(f'Training optimizer: {opt_config}')
    logging.info(f'Training LR Schedule: {scheduler_config}')

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info(f'Number of parameters: {num_parameters}')

    # get data loaders
    train_loader = train_data.get_loader(batch_size=args.batch_size,
                                         seeds=shuffling_seeds,
                                         batch_first=batch_first,
                                         shuffle=True,
                                         bucketing=args.bucketing,
                                         num_workers=args.train_loader_workers)

    gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE,
               value=args.batch_size * get_world_size())
    gnmt_print(key=mlperf_log.INPUT_SIZE,
               value=train_loader.sampler.num_samples)

    val_loader = val_data.get_loader(batch_size=args.val_batch_size,
                                     batch_first=batch_first,
                                     shuffle=False,
                                     num_workers=args.val_loader_workers)

    test_loader = test_data.get_loader(batch_size=args.test_batch_size,
                                       batch_first=batch_first,
                                       shuffle=False,
                                       pad=True,
                                       num_workers=args.test_loader_workers)

    gnmt_print(key=mlperf_log.EVAL_SIZE, value=len(test_loader.dataset))

    translator = Translator(model=model,
                            tokenizer=tokenizer,
                            loader=test_loader,
                            beam_size=args.beam_size,
                            max_seq_len=args.max_length_test,
                            len_norm_factor=args.len_norm_factor,
                            len_norm_const=args.len_norm_const,
                            cov_penalty_factor=args.cov_penalty_factor,
                            cuda=args.cuda,
                            print_freq=args.print_freq,
                            dataset_dir=args.dataset_dir,
                            target_bleu=args.target_bleu,
                            save_path=args.save_path)

    # create trainer
    trainer_options = dict(
        criterion=criterion,
        grad_clip=args.grad_clip,
        save_path=save_path,
        save_freq=args.save_freq,
        save_info={'config': args, 'tokenizer': tokenizer.get_state()},
        opt_config=opt_config,
        scheduler_config=scheduler_config,
        batch_first=batch_first,
        keep_checkpoints=args.keep_checkpoints,
        math=args.math,
        print_freq=args.print_freq,
        cuda=args.cuda,
        distributed=distributed,
        distributed_overlap_allreduce=args.enable_apex_allreduce_overlap,
        distributed_overlap_allreduce_messagesize=args.apex_message_size,
        intra_epoch_eval=args.intra_epoch_eval,
        translator=translator,
        arch=args.arch)

    trainer_options['model'] = model
    trainer = trainers.Seq2SeqTrainer(**trainer_options)

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(checkpoint_file, 'model_best.pth')
        if os.path.isfile(checkpoint_file):
            trainer.load(checkpoint_file)
        else:
            logging.error(f'No checkpoint found at {args.resume}')

    # training loop
    # best_loss = float('inf')
    gnmt_print(key=mlperf_log.TRAIN_LOOP)

    for epoch in range(1):
        logging.info(f'Starting epoch {epoch}')
        gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=epoch)
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        trainer.epoch = epoch
        train_loss, train_perf = trainer.optimize(train_loader)

        logging.info(f'Finished epoch {epoch}')

    # Save the checkpoint at the end of the training loop, after the RUN_STOP
    # tag
    # https://github.com/mlperf/policies/issues/55#issuecomment-428335773
    if not args.disable_eval:
        gnmt_print(key=mlperf_log.TRAIN_CHECKPOINT)
        if get_rank() == 0:
            trainer.save(save_all=args.save_all, is_best=True)

    gnmt_print(key=mlperf_log.RUN_FINAL)
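# Illustration (not part of the original source): the seeding helpers used
# by main() above. A minimal sketch of the assumed behaviour: each worker
# draws candidate seeds from its local master RNG, then rank 0's values
# are broadcast so every worker ends up with identical seeds.
def generate_seeds(rng, size):
    # one seed per worker / per epoch, drawn from the master RNG
    return [rng.randint(0, 2**32 - 1) for _ in range(size)]

def broadcast_seeds(seeds, device):
    # overwrite local seeds with rank 0's seeds when running distributed
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        seeds_tensor = torch.tensor(seeds, dtype=torch.int64, device=device)
        torch.distributed.broadcast(seeds_tensor, 0)
        seeds = seeds_tensor.tolist()
    return seeds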
def main(): """ Launches data-parallel multi-gpu training. """ mlperf_log.ROOT_DIR_GNMT = os.path.dirname(os.path.abspath(__file__)) mlperf_log.LOGGER.propagate = False args = parse_args() device = utils.set_device(args.cuda, args.local_rank) distributed = utils.init_distributed(args.cuda) gnmt_print(key=mlperf_log.RUN_START, sync=True) args.rank = utils.get_rank() if not args.cudnn: torch.backends.cudnn.enabled = False # create directory for results save_path = os.path.join(args.results_dir, args.save) args.save_path = save_path os.makedirs(save_path, exist_ok=True) # setup logging log_filename = f'log_rank_{utils.get_rank()}.log' utils.setup_logging(os.path.join(save_path, log_filename)) if args.env: utils.log_env_info() logging.info(f'Saving results to: {save_path}') logging.info(f'Run arguments: {args}') # automatically set train_iter_size based on train_global_batch_size, # world_size and per-worker train_batch_size if args.train_global_batch_size is not None: global_bs = args.train_global_batch_size bs = args.train_batch_size world_size = utils.get_world_size() assert global_bs % (bs * world_size) == 0 args.train_iter_size = global_bs // (bs * world_size) logging.info(f'Global batch size was set in the config, ' f'Setting train_iter_size to {args.train_iter_size}') worker_seeds, shuffling_seeds = utils.setup_seeds(args.seed, args.epochs, device) worker_seed = worker_seeds[args.rank] logging.info(f'Worker {args.rank} is using worker seed: {worker_seed}') torch.manual_seed(worker_seed) # build tokenizer pad_vocab = utils.pad_vocabulary(args.math) tokenizer = Tokenizer(os.path.join(args.dataset_dir, config.VOCAB_FNAME), pad_vocab) # build datasets gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_TRAINING, sync=False) gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN, value=args.max_length_train, sync=False) train_data = LazyParallelDataset( src_fname=os.path.join(args.dataset_dir, config.SRC_TRAIN_FNAME), tgt_fname=os.path.join(args.dataset_dir, config.TGT_TRAIN_FNAME), tokenizer=tokenizer, min_len=args.min_length_train, max_len=args.max_length_train, sort=False, max_size=args.max_size) gnmt_print(key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES, value=len(train_data), sync=False) val_data = ParallelDataset(src_fname=os.path.join(args.dataset_dir, config.SRC_VAL_FNAME), tgt_fname=os.path.join(args.dataset_dir, config.TGT_VAL_FNAME), tokenizer=tokenizer, min_len=args.min_length_val, max_len=args.max_length_val, sort=True) gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_EVAL, sync=False) test_data = TextDataset(src_fname=os.path.join(args.dataset_dir, config.SRC_TEST_FNAME), tokenizer=tokenizer, min_len=args.min_length_test, max_len=args.max_length_test, sort=True) gnmt_print(key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES, value=len(test_data), sync=False) vocab_size = tokenizer.vocab_size gnmt_print(key=mlperf_log.PREPROC_VOCAB_SIZE, value=vocab_size, sync=False) # build GNMT model model_config = { 'hidden_size': args.hidden_size, 'num_layers': args.num_layers, 'dropout': args.dropout, 'batch_first': False, 'share_embedding': args.share_embedding } model = GNMT(vocab_size=vocab_size, **model_config) logging.info(model) batch_first = model.batch_first # define loss function (criterion) and optimizer criterion = build_criterion(vocab_size, config.PAD, args.smoothing) opt_config = {'optimizer': args.optimizer, 'lr': args.lr} opt_config.update(literal_eval(args.optimizer_extra)) logging.info(f'Training optimizer config: {opt_config}') scheduler_config = { 'warmup_steps': args.warmup_steps, 'remain_steps': 
    scheduler_config = {'warmup_steps': args.warmup_steps,
                        'remain_steps': args.remain_steps,
                        'decay_interval': args.decay_interval,
                        'decay_steps': args.decay_steps,
                        'decay_factor': args.decay_factor}
    logging.info(f'Training LR schedule config: {scheduler_config}')

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info(f'Number of parameters: {num_parameters}')

    batching_opt = {'shard_size': args.shard_size,
                    'num_buckets': args.num_buckets}

    # get data loaders
    train_loader = train_data.get_loader(batch_size=args.train_batch_size,
                                         seeds=shuffling_seeds,
                                         batch_first=batch_first,
                                         shuffle=True,
                                         batching=args.batching,
                                         batching_opt=batching_opt,
                                         num_workers=args.train_loader_workers)

    gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE,
               value=args.train_batch_size * utils.get_world_size(),
               sync=False)
    gnmt_print(key=mlperf_log.INPUT_SIZE,
               value=train_loader.sampler.num_samples, sync=False)

    val_loader = val_data.get_loader(batch_size=args.val_batch_size,
                                     batch_first=batch_first,
                                     shuffle=False,
                                     num_workers=args.val_loader_workers)

    test_loader = test_data.get_loader(batch_size=args.test_batch_size,
                                       batch_first=batch_first,
                                       shuffle=False,
                                       pad=True,
                                       num_workers=args.test_loader_workers)

    gnmt_print(key=mlperf_log.EVAL_SIZE,
               value=len(test_loader.dataset), sync=False)

    translator = Translator(model=model,
                            tokenizer=tokenizer,
                            loader=test_loader,
                            beam_size=args.beam_size,
                            max_seq_len=args.max_length_test,
                            len_norm_factor=args.len_norm_factor,
                            len_norm_const=args.len_norm_const,
                            cov_penalty_factor=args.cov_penalty_factor,
                            cuda=args.cuda,
                            print_freq=args.print_freq,
                            dataset_dir=args.dataset_dir,
                            target_bleu=args.target_bleu,
                            save_path=args.save_path)

    # create trainer
    total_train_iters = len(train_loader) // args.train_iter_size * args.epochs

    save_info = {'model_config': model_config,
                 'config': args,
                 'tokenizer': tokenizer.get_state()}

    trainer_options = dict(
        criterion=criterion,
        grad_clip=args.grad_clip,
        iter_size=args.train_iter_size,
        save_path=save_path,
        save_freq=args.save_freq,
        save_info=save_info,
        opt_config=opt_config,
        scheduler_config=scheduler_config,
        train_iterations=total_train_iters,
        batch_first=batch_first,
        keep_checkpoints=args.keep_checkpoints,
        math=args.math,
        print_freq=args.print_freq,
        cuda=args.cuda,
        distributed=distributed,
        intra_epoch_eval=args.intra_epoch_eval,
        translator=translator)

    trainer_options['model'] = model
    trainer = trainers.Seq2SeqTrainer(**trainer_options)

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(checkpoint_file, 'model_best.pth')
        if os.path.isfile(checkpoint_file):
            trainer.load(checkpoint_file)
        else:
            logging.error(f'No checkpoint found at {args.resume}')

    # training loop
    best_loss = float('inf')
    break_training = False
    test_bleu = None
    gnmt_print(key=mlperf_log.TRAIN_LOOP, sync=True)

    for epoch in range(args.start_epoch, args.epochs):
        logging.info(f'Starting epoch {epoch}')
        gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=epoch, sync=True)

        train_loader.sampler.set_epoch(epoch)

        trainer.epoch = epoch
        train_loss, train_perf = trainer.optimize(train_loader)

        # evaluate on validation set
        if args.eval:
            logging.info(f'Running validation on dev set')
            val_loss, val_perf = trainer.evaluate(val_loader)

        # remember best prec@1 and save checkpoint
        gnmt_print(key=mlperf_log.TRAIN_CHECKPOINT, sync=False)
        if args.rank == 0:
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
            trainer.save(save_all=args.save_all, is_best=is_best)

        if args.eval:
            gnmt_print(key=mlperf_log.EVAL_START, value=epoch, sync=True)
            test_bleu, break_training = translator.run(calc_bleu=True,
                                                       epoch=epoch)
            gnmt_print(key=mlperf_log.EVAL_ACCURACY,
                       value={"epoch": epoch, "value": round(test_bleu, 2)},
                       sync=False)
            gnmt_print(key=mlperf_log.EVAL_TARGET,
                       value=args.target_bleu, sync=False)
            gnmt_print(key=mlperf_log.EVAL_STOP, sync=True)

        acc_log = []
        acc_log += [f'Summary: Epoch: {epoch}']
        acc_log += [f'Training Loss: {train_loss:.4f}']
        if args.eval:
            acc_log += [f'Validation Loss: {val_loss:.4f}']
            acc_log += [f'Test BLEU: {test_bleu:.2f}']

        perf_log = []
        perf_log += [f'Performance: Epoch: {epoch}']
        perf_log += [f'Training: {train_perf:.0f} Tok/s']
        if args.eval:
            perf_log += [f'Validation: {val_perf:.0f} Tok/s']

        if args.rank == 0:
            logging.info('\t'.join(acc_log))
            logging.info('\t'.join(perf_log))

        logging.info(f'Finished epoch {epoch}')
        if break_training:
            break

    gnmt_print(key=mlperf_log.RUN_STOP,
               value={"success": bool(break_training)}, sync=True)
    gnmt_print(key=mlperf_log.RUN_FINAL, sync=False)
def __init__(self, model, criterion, opt_config, scheduler_config,
             print_freq=10, save_freq=1000, grad_clip=float('inf'),
             batch_first=False, save_info={}, save_path='.',
             checkpoint_filename='checkpoint%s.pth', keep_checkpoints=5,
             math='fp32', cuda=True, distributed=False,
             distributed_overlap_allreduce=False,
             distributed_overlap_allreduce_messagesize=1e7,
             intra_epoch_eval=0, translator=None, verbose=False,
             arch="gnmt"):
    super(Seq2SeqTrainer, self).__init__()
    self.model = model
    self.criterion = criterion
    self.epoch = 0
    self.save_info = save_info
    self.save_path = save_path
    self.save_freq = save_freq
    self.save_counter = 0
    self.checkpoint_filename = checkpoint_filename
    self.checkpoint_counter = cycle(range(keep_checkpoints))
    self.opt_config = opt_config
    self.cuda = cuda
    self.distributed = distributed
    self.print_freq = print_freq
    self.batch_first = batch_first
    self.verbose = verbose
    self.loss = None
    self.translator = translator
    self.intra_epoch_eval = intra_epoch_eval
    self.arch = arch

    self.retain_allreduce_buffers = True
    self.gradient_average = False

    if cuda:
        self.model = self.model.cuda()
        self.criterion = self.criterion.cuda()

    if math == 'fp16':
        self.model = self.model.half()
        if distributed:
            # self.model = apex.parallel.DistributedDataParallel(self.model, message_size=10000000, delay_allreduce=True)
            self.model = apex.parallel.DistributedDataParallel(
                self.model,
                message_size=distributed_overlap_allreduce_messagesize,
                delay_allreduce=(not distributed_overlap_allreduce),
                retain_allreduce_buffers=self.retain_allreduce_buffers,
                gradient_average=self.gradient_average)
        self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
        params = [self.fp_optimizer.fp32_params]
    elif math == 'fp32':
        if distributed:
            # self.model = apex.parallel.DistributedDataParallel(self.model, message_size=10000000, delay_allreduce=True)
            self.model = apex.parallel.DistributedDataParallel(
                self.model,
                message_size=distributed_overlap_allreduce_messagesize,
                delay_allreduce=(not distributed_overlap_allreduce))
        self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
        params = self.model.parameters()

    opt_name = opt_config.pop('optimizer')
    if opt_name == 'FusedAdam':
        self.optimizer = apex.optimizers.FusedAdam(params, **opt_config)
    else:
        self.optimizer = torch.optim.__dict__[opt_name](params, **opt_config)

    gnmt_print(key=mlperf_log.OPT_NAME,
               value=mlperf_log.ADAM)
    gnmt_print(key=mlperf_log.OPT_LR,
               value=opt_config['lr'])
    gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
               value=self.optimizer.defaults['betas'][0])
    gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
               value=self.optimizer.defaults['betas'][1])
    gnmt_print(key=mlperf_log.OPT_HP_ADAM_EPSILON,
               value=self.optimizer.defaults['eps'])

    self.scheduler = WarmupMultiStepLR(
        self.optimizer,
        lr_method=scheduler_config["lr_method"],
        warmup_iters=scheduler_config["warmup_iters"],
        remain_steps=scheduler_config["remain_steps"],
        decay_steps=scheduler_config["decay_steps"])

    logging.info(f'Using optimizer: {self.optimizer}')
def __init__(self, model, criterion, opt_config, scheduler_config,
             print_freq=10, save_freq=1000, grad_clip=float('inf'),
             batch_first=False, save_info={}, save_path='.',
             train_iterations=0, checkpoint_filename='checkpoint%s.pth',
             keep_checkpoints=5, math='fp32', cuda=True, distributed=False,
             intra_epoch_eval=0, iter_size=1, translator=None,
             verbose=False):
    """
    Constructor for the Seq2SeqTrainer.

    :param model: model to train
    :param criterion: criterion (loss function)
    :param opt_config: dictionary with options for the optimizer
    :param scheduler_config: dictionary with options for the learning rate
        scheduler
    :param print_freq: prints short summary every 'print_freq' iterations
    :param save_freq: saves checkpoint every 'save_freq' iterations
    :param grad_clip: coefficient for gradient clipping
    :param batch_first: if True the model uses (batch, seq, feature)
        tensors, if False the model uses (seq, batch, feature)
    :param save_info: dict with additional state stored in each checkpoint
    :param save_path: path to the directory for checkpoints
    :param train_iterations: total number of training iterations to execute
    :param checkpoint_filename: name of files with checkpoints
    :param keep_checkpoints: max number of checkpoints to keep
    :param math: arithmetic type
    :param cuda: if True use cuda, if False train on cpu
    :param distributed: if True run distributed training
    :param intra_epoch_eval: number of additional eval runs within each
        training epoch
    :param iter_size: number of iterations between weight updates
    :param translator: instance of Translator, runs inference on test set
    :param verbose: enables verbose logging
    """
    super(Seq2SeqTrainer, self).__init__()
    self.model = model
    self.criterion = criterion
    self.epoch = 0
    self.save_info = save_info
    self.save_path = save_path
    self.save_freq = save_freq
    self.save_counter = 0
    self.checkpoint_filename = checkpoint_filename
    self.checkpoint_counter = cycle(range(keep_checkpoints))
    self.opt_config = opt_config
    self.cuda = cuda
    self.distributed = distributed
    self.print_freq = print_freq
    self.batch_first = batch_first
    self.verbose = verbose
    self.loss = None
    self.translator = translator
    self.intra_epoch_eval = intra_epoch_eval
    self.iter_size = iter_size

    if cuda:
        self.model = self.model.cuda()
        self.criterion = self.criterion.cuda()

    if math == 'fp16':
        self.model = self.model.half()

    if distributed:
        self.model = DDP(self.model)

    if math == 'fp16':
        self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
        params = self.fp_optimizer.fp32_params
    elif math == 'fp32':
        self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
        params = self.model.parameters()

    opt_name = opt_config.pop('optimizer')
    self.optimizer = torch.optim.__dict__[opt_name](params, **opt_config)
    logging.info(f'Using optimizer: {self.optimizer}')

    gnmt_print(key=mlperf_log.OPT_NAME,
               value=mlperf_log.ADAM, sync=False)
    gnmt_print(key=mlperf_log.OPT_LR,
               value=opt_config['lr'], sync=False)
    gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
               value=self.optimizer.defaults['betas'][0], sync=False)
    gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
               value=self.optimizer.defaults['betas'][1], sync=False)
    gnmt_print(key=mlperf_log.OPT_HP_ADAM_EPSILON,
               value=self.optimizer.defaults['eps'], sync=False)

    self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations,
                                       **scheduler_config)
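# Illustration (not part of the original source): the mixed-precision
# pattern behind the Fp16Optimizer branches above, sketched under the
# assumption that a master fp32 copy of the half-precision weights is kept
# and handed to the optimizer, with the fp16 model refreshed from it after
# each update. Not the repo's implementation.
import torch

def make_fp32_master_params(fp16_model):
    # flat fp32 copy of all fp16 parameters, used as the optimizer's params
    flat = torch.cat([p.detach().reshape(-1).float()
                      for p in fp16_model.parameters()])
    return flat.clone().requires_grad_(True)

def copy_master_to_model(master, fp16_model):
    # write the updated fp32 master weights back into the fp16 model
    offset = 0
    for p in fp16_model.parameters():
        n = p.numel()
        p.data.copy_(master[offset:offset + n].view_as(p))
        offset += n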