def save_checkpoint(self, filename, extra_state):
    """Save all training state in a checkpoint file."""
    if distributed_utils.is_master(self.args):  # only save one checkpoint
        extra_state['train_meters'] = self.meters
        checkpoint_utils.save_state(
            filename, self.args, self.get_model().state_dict(), None,
            self.optimizer, self.lr_scheduler, self.get_num_updates(),
            self._optim_history, extra_state,
        )
def evaluate(args, model, va_loader):
    """Evaluate on validation data."""
    # Keep non-master processes waiting here while the master evaluates.
    if not is_master(args):
        accuracy = torch.zeros([1]).cuda()
        torch.distributed.barrier()

    # Only the master performs evaluation.
    if is_master(args):
        num_correct, num_example = 0, 0
        num_tp, num_fp, num_tn, num_fn = 0, 0, 0, 0
        model.eval()
        with torch.no_grad():
            for sent, seg_id, label in va_loader.get_iter(shuffle=False):
                _, ret_dict = model(sent, seg_id=seg_id, cls_target=label)
                cls_corr = ret_dict["cls_corr"]
                num_correct += cls_corr
                num_example += len(sent)
                tp, fp, tn, fn = confusion_matrix(ret_dict["cls_pred"], label)
                num_tp = num_tp + tp
                num_fp = num_fp + fp
                num_tn = num_tn + tn
                num_fn = num_fn + fn
        model.train()

        if args.dataset in ["CoLA"]:
            # CoLA is scored with a correlation metric rather than accuracy.
            accuracy = _compute_metric_based_on_keys(
                "corr", num_tp.item(), num_fp.item(), num_tn.item(),
                num_fn.item())
            accuracy = torch.FloatTensor([accuracy]).cuda()
        else:
            accuracy = num_correct / num_example

        # Release the waiting non-master processes.
        if args.distributed:
            torch.distributed.barrier()

    # sync accuracy: only the master holds a non-zero value, so a SUM
    # reduction gives every process the master's result.
    if args.distributed:
        torch.distributed.all_reduce(accuracy,
                                     op=torch.distributed.ReduceOp.SUM)
    return accuracy.item()
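# NOTE: `confusion_matrix` and `_compute_metric_based_on_keys` are referenced
# above but not defined in this file. The following is a minimal sketch of
# what they are assumed to compute for binary classification (per-batch
# confusion counts and the Matthews correlation used for CoLA); it is
# illustrative, not the original implementation.
def confusion_matrix(pred, label):
    """Per-batch confusion counts for binary classification (sketch)."""
    pred, label = pred.view(-1), label.view(-1)
    tp = ((pred == 1) & (label == 1)).sum()
    fp = ((pred == 1) & (label == 0)).sum()
    tn = ((pred == 0) & (label == 0)).sum()
    fn = ((pred == 0) & (label == 1)).sum()
    return tp, fp, tn, fn


def _compute_metric_based_on_keys(key, tp, fp, tn, fn):
    """Matthews correlation coefficient from confusion counts (sketch)."""
    if key == "corr":
        denom = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
        return (tp * tn - fp * fn) / denom if denom > 0 else 0.0
    raise ValueError("unknown metric key {}".format(key))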
def setup_special_ids(args, tokenizer):
    """Set up the id of special tokens."""
    special_symbols_mapping = collections.OrderedDict([
        ("<unk>", "unk_id"),
        ("<s>", "bos_id"),
        ("</s>", "eos_id"),
        ("<cls>", "cls_id"),
        ("<sep>", "sep_id"),
        ("<pad>", "pad_id"),
        ("<mask>", "mask_id"),
        ("<eod>", "eod_id"),
        ("<eop>", "eop_id"),
    ])
    args.vocab_size = tokenizer.get_vocab_size()
    if is_master(args):
        print("Set vocab_size: {}.".format(args.vocab_size))
    for sym, sym_id_str in special_symbols_mapping.items():
        try:
            sym_id = tokenizer.get_token_id(sym)
            setattr(args, sym_id_str, sym_id)
            if is_master(args):
                print("Set {} to {}.".format(sym_id_str, sym_id))
        except KeyError:
            if is_master(args):
                print("Skip {}: not found in tokenizer's vocab.".format(sym))
def predict(args, model, loader, out_path, rev_label_dict):
    """Make predictions and write them to a file.

    This should only be called by the master process.
    """
    # Only the master performs prediction.
    if is_master(args):
        model.eval()
        with open(out_path, "w") as fo:
            with torch.no_grad():
                for sent, seg_id, label in loader.get_iter(shuffle=False):
                    _, ret_dict = model(sent, seg_id=seg_id, cls_target=label)
                    cls_pred = ret_dict["cls_pred"]
                    for i in range(cls_pred.size(0)):
                        # Map the predicted class id back to its label string
                        # (avoid shadowing the `label` tensor from the loader).
                        pred_label = rev_label_dict[cls_pred[i].item()]
                        fo.write("{}\n".format(pred_label))
        model.train()
def main(args):
    args = options.set_default_args(args)

    if args.ddp_backend == 'apex':
        from apex.parallel import DistributedDataParallel as DDP
    else:
        from torch.nn.parallel import DistributedDataParallel as DDP

    ############################################################################
    # Random seed
    ############################################################################
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    ############################################################################
    # Experiment & Logging
    ############################################################################
    if is_master(args):
        if args.resume:
            # rank-0 device logs to the file in the existing experiment dir
            logging = get_logger(os.path.join(args.expname, 'log.txt'),
                                 log_=not args.debug)
        else:
            # rank-0 device creates the experiment dir and logs to the file
            logging = create_exp_dir(args.expname, debug=args.debug)
    else:
        # other devices only log to console (print) but not to the file
        logging = get_logger(log_path=None, log_=False)

    args.model_path = os.path.join(args.expname, 'model.pt')
    args.var_path = os.path.join(args.expname, 'var.pt')

    ############################################################################
    # Load data
    ############################################################################
    logging('Loading data..')
    # the test split is used for the final evaluation below
    tr_data, va_data, te_data = options.load_data(args)

    train_step = 0
    best_eval_ll = -float('inf')

    if args.resume:
        logging('Resuming from {}...'.format(args.resume))
        model, opt = torch.load(args.model_path, map_location='cpu')
        model = model.to(args.device)
        # move the optimizer state onto the training device
        for state in opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(args.device)
        best_eval_ll, train_step = torch.load(args.var_path)
    else:
        logging('Building model..')
        if args.model_name in ['srnn', 'srnn_zforce', 'srnn_hier']:
            model = eval(args.model_name).Model(
                args.n_mix, args.d_data, args.d_emb, args.d_mlp, args.d_rnn,
                args.d_lat, dropout=args.dropout, n_layer=args.n_layer)
        elif args.model_name in ['rnn', 'rnn_hier']:
            model = eval(args.model_name).Model(
                args.n_mix, args.d_data, args.d_emb, args.d_rnn,
                dropout=args.dropout, n_layer=args.n_layer)
        else:
            raise ValueError('unsupported model type {}'.format(
                args.model_name))
        model = model.to(args.device)

        # create new optimizer
        opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    # criterion params and model params (lists are initialized unconditionally
    # so the parameter counts below do not fail in test-only mode)
    crit_params, model_params = [], []
    if not args.test_only:
        for n, p in model.named_parameters():
            if 'crit' in n:
                crit_params.append(p)
            else:
                model_params.append(p)

    ############################################################################
    # Distributed Data Parallel
    ############################################################################
    if args.distributed:
        if args.ddp_backend == 'apex':
            torch.cuda.set_device(args.distributed_rank)
            para_model = DDP(model)
        else:
            para_model = DDP(model, device_ids=[args.device_id],
                             output_device=args.device_id)
    else:
        para_model = model

    ############################################################################
    # Log args
    ############################################################################
    args.n_crit_param = sum([p.nelement() for p in crit_params])
    args.n_model_param = sum([p.nelement() for p in model_params])
    args.n_param = args.n_crit_param + args.n_model_param
    if is_master(args):
        logging('=' * 100)
        for k, v in args.__dict__.items():
            logging(' - {} : {}'.format(k, v))
        logging('=' * 100)
    ############################################################################
    # Training
    ############################################################################
    # linear KL-weight annealing
    kld_weight = min(1., args.init_kld + train_step * args.kld_incr)

    loss_sum = torch.Tensor([0]).to(args.device)
    kld_sum = torch.Tensor([0]).to(args.device)
    nll_sum = torch.Tensor([0]).to(args.device)
    gnorm_sum = 0

    t = timeit.default_timer()
    for epoch in range(args.num_epochs):
        model.train()
        # make sure all data iterators use the same seed to shuffle data
        if args.distributed:
            np.random.seed(args.seed + epoch)

        # initialize the hidden state
        if args.pass_h:
            hidden = model.init_hidden(args.batch_size)
        else:
            hidden = None

        for x, y, mask in tr_data.get_masked_iter(shuffle=True):
            opt.zero_grad()
            ratio = 1. / torch.sum(mask)
            if args.kld:
                nll_loss, kld_loss, hidden = para_model(x, y, mask=mask,
                                                        hidden=hidden)
                nll_loss = nll_loss.sum() * ratio
                kld_loss = kld_loss.sum() * ratio
                train_loss = nll_loss - kld_loss * kld_weight
                train_loss.backward()

                total_loss = nll_loss.detach() - kld_loss.detach()
                kld_sum += -kld_loss.detach()
                nll_sum += nll_loss.detach()
            else:
                nll_loss, hidden = para_model(x, y, mask=mask, hidden=hidden)
                train_loss = nll_loss.sum() * ratio
                train_loss.backward()

                total_loss = train_loss.detach()

            if args.clip > 0:
                gnorm = nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            else:
                # no clipping: compute the global gradient norm manually
                gnorm = 0
                for n, p in model.named_parameters():
                    param_gnorm = p.grad.data.norm(2)
                    gnorm += param_gnorm.item() ** 2
                gnorm = gnorm ** (1. / 2)

            opt.step()

            gnorm_sum += gnorm
            loss_sum += total_loss
            train_step += 1

            # lr & KL-weight annealing
            kld_weight = min(1., kld_weight + args.kld_incr)
            adjust_lr(opt, train_step, args.max_step, args.lr, args.end_lr)

            # log training
            if train_step % args.log_interval == 0:
                if args.distributed:
                    dist.reduce(loss_sum, dst=0, op=dist.ReduceOp.SUM)
                    loss_sum = loss_sum.div_(args.distributed_world_size)
                    dist.reduce(nll_sum, dst=0, op=dist.ReduceOp.SUM)
                    nll_sum = nll_sum.div_(args.distributed_world_size)
                    dist.reduce(kld_sum, dst=0, op=dist.ReduceOp.SUM)
                    kld_sum = kld_sum.div_(args.distributed_world_size)
                if is_master(args):
                    cur_loss = loss_sum.item() / args.log_interval
                    cur_nll = nll_sum.item() / args.log_interval
                    cur_kld = kld_sum.item() / args.log_interval
                    elapsed = (timeit.default_timer() - t) / 3600
                    logging('| total hrs [{:.2f}] | epoch {} step {} '
                            '| lr {:8.6f}, klw {:7.5f} | LL {:>9.4f} '
                            '| nll_loss {:>7.4f}, kld_loss {:>8.4f} '
                            '| gnorm {:.4f}'.format(
                                elapsed, epoch, train_step,
                                opt.param_groups[0]['lr'], kld_weight,
                                -cur_loss, cur_nll, cur_kld,
                                gnorm_sum / args.log_interval))

                loss_sum = torch.Tensor([0]).to(args.device)
                kld_sum = torch.Tensor([0]).to(args.device)
                nll_sum = torch.Tensor([0]).to(args.device)
                gnorm_sum = 0

            # validation
            if train_step % args.eval_interval == 0:
                eval_ll = evaluate(va_data, model, args)
                if is_master(args):
                    logging('-' * 120)
                    logging('Eval [{}] at step: {} | valid LL: {:>8.4f}'.format(
                        train_step // args.eval_interval, train_step, eval_ll))
                    if eval_ll > best_eval_ll:
                        best_eval_ll = eval_ll
                        if not args.debug:
                            logging('Save checkpoint. '
                                    'Best valid LL {:>9.4f}'.format(eval_ll))
                            torch.save([model, opt], args.model_path)
                            torch.save([best_eval_ll, train_step],
                                       args.var_path)
                    logging('-' * 120)

            # reached the maximum number of training steps
            if train_step == args.max_step:
                break
        if train_step == args.max_step:
            break

    # final validation of the current model
    eval_ll = evaluate(va_data, model, args)
    if is_master(args):
        logging('-' * 120)
        logging('Eval [{}] | step: {}, LL: {:>8.4f}'.format(
            train_step // args.eval_interval, train_step, eval_ll))
        logging('-' * 120)

    # evaluate the current model on the test data
    test_loss = evaluate(te_data, model, args)
    if is_master(args):
        logging('Test -- LL: {:>8.4f}'.format(test_loss))
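# NOTE: `adjust_lr(opt, step, max_step, lr, end_lr)` is called above but not
# defined in this file. The sketch below assumes a simple linear annealing of
# the learning rate from `lr` to `end_lr` over `max_step` updates; the
# original schedule may differ.
def adjust_lr(opt, step, max_step, lr, end_lr):
    """Linearly anneal the learning rate from `lr` to `end_lr` (sketch)."""
    frac = min(step, max_step) / max_step
    new_lr = lr + (end_lr - lr) * frac
    for param_group in opt.param_groups:
        param_group['lr'] = new_lr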
def main(args, init_distributed=False):
    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    print(args, flush=True)

    # Setup task, e.g., translation, language modeling, etc.
    task = None
    if args.task == 'bert':
        task = tasks.LanguageModelingTask.setup_task(args)
    elif args.task == 'mnist':
        task = tasks.MNISTTask.setup_task(args)
    assert task is not None, 'unsupported task {}'.format(args.task)

    # Load valid datasets (training data is loaded below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model
    model = task.build_model(args)
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build controller
    controller = Controller(args, task, model)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, controller)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = controller.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while (lr > args.min_lr
           and (epoch_itr.epoch < max_epoch
                or (epoch_itr.epoch == max_epoch
                    and epoch_itr._next_epoch_itr is not None))
           and controller.get_num_updates() < max_update):
        # train for one epoch
        train(args, controller, task, epoch_itr)

        # no validation loss is computed here; only the first entry is used
        # to update the learning rate
        valid_losses = [None]
        lr = controller.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, controller, epoch_itr,
                                             valid_losses[0])

        # sharded data: get train iterator for next epoch
        reload_dataset = ':' in getattr(args, 'data', '')
        epoch_itr = controller.get_train_iterator(epoch_itr.epoch,
                                                  load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
def stat(args, data):
    lengths = [len(x[0]) for x in data]
    if is_master(args):
        print("Number of sent: {}".format(len(data)))
        print("Sent length: mean {}, std {}, max {}".format(
            np.mean(lengths), np.std(lengths), np.max(lengths)))
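# NOTE: `is_master(args)` is used throughout but not defined in this file.
# A minimal sketch, assuming the usual convention that rank 0 (or the single
# process in a non-distributed run) acts as the master; the real helper may
# live in a shared utility module.
def is_master(args):
    """Return True on the process that should log, save, and predict (sketch)."""
    return not getattr(args, "distributed", False) or args.distributed_rank == 0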
def convert_examples_to_tensors(args, examples, label_dict, tokenizer,
                                output_file):
    """Encode and cache raw data into pytorch format."""
    # Only the master process builds the cache; the other processes wait at
    # the barrier and then load the cached file.
    if not is_master(args) and args.distributed:
        torch.distributed.barrier()

    if not os.path.exists(output_file) or args.overwrite_data:
        sents, labels, seg_ids = [], [], []
        for (ex_index, example) in enumerate(examples):
            example_len = 0
            tokens_a = tokenizer.convert_text_to_ids(example.text_a)
            example_len += len(tokens_a)

            tokens_b = None
            if example.text_b:
                tokens_b = tokenizer.convert_text_to_ids(example.text_b)
                example_len += len(tokens_b)

            if tokens_b:
                # Modifies `tokens_a` and `tokens_b` in place so that the total
                # length is less than the specified length.
                # Account for two [SEP] & one [CLS] with "- 3"
                _truncate_seq_pair(tokens_a, tokens_b, args.max_length - 3)
            else:
                # Account for one [SEP] & one [CLS] with "- 2"
                if len(tokens_a) > args.max_length - 2:
                    tokens_a = tokens_a[:args.max_length - 2]

            input_ids = []
            segment_ids = []
            if tokens_b is not None:
                input_ids = ([args.cls_id] + tokens_a + [args.sep_id] +
                             tokens_b + [args.sep_id])
                segment_ids = ([args.seg_id_cls] +
                               [args.seg_id_a] * (len(tokens_a) + 1) +
                               [args.seg_id_b] * (len(tokens_b) + 1))
            else:
                input_ids = [args.cls_id] + tokens_a + [args.sep_id]
                segment_ids = ([args.seg_id_cls] +
                               [args.seg_id_a] * (len(tokens_a) + 1))

            # Label
            if label_dict is not None:
                label_id = label_dict[example.label]
            else:
                label_id = example.label

            input_ids = torch.LongTensor(input_ids)
            segment_ids = torch.LongTensor(segment_ids)

            sents.append(input_ids)
            seg_ids.append(segment_ids)
            labels.append(label_id)

        data = list(zip(sents, seg_ids, labels))
        torch.save(data, output_file)
    else:
        data = torch.load(output_file)

    # Release the waiting processes once the cache exists.
    if is_master(args) and args.distributed:
        torch.distributed.barrier()

    stat(args, data)
    return data
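# NOTE: `_truncate_seq_pair` is referenced above but not defined in this file.
# The sketch below follows the standard BERT-style helper that the in-line
# comment describes (truncate the pair in place to a maximum total length);
# the original implementation may differ in detail.
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncate a sequence pair in place to a maximum total length (sketch).

    Removes one token at a time from the longer of the two sequences, so the
    shorter sequence keeps as much of its content as possible.
    """
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()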
def main(args): """Main training function.""" torch.cuda.set_device(args.device_id) if args.distributed: args.distributed_rank = args.device_id distributed_init(args) if args.seed is not None: np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) options.setup_device(args) ############################################################################ # Experiment & Logging ############################################################################ if is_master(args): if args.resume: # rank-0 device creates experiment dir and log to the file logging = utils.get_logger(os.path.join(args.model_dir, "log.txt"), log_=not args.debug) else: # rank-0 device creates experiment dir and log to the file logging = utils.create_exp_dir(args.model_dir, debug=args.debug) else: # other devices only log to console (print) but not the file logging = utils.get_logger(log_path=None, log_=False) ############################################################################ # Load data ############################################################################ logging("Loading data..") loaded_data, label_dict = data.load_data(args) args.num_class = len(label_dict) logging("Loading finish") tr_data, va_data, te_data = loaded_data va_loader = data.BucketIterator(va_data, args.valid_bsz, args.pad_id, args.seg_id_pad, args.device, args.max_length) te_loader = data.BucketIterator(te_data, args.test_bsz, args.pad_id, args.seg_id_pad, args.device, args.max_length) options.setup_device(args) args.model_path = os.path.join(args.model_dir, "model.pt") args.var_path = os.path.join(args.model_dir, "var.pt") args.config_path = os.path.join(args.model_dir, "net_config.json") train_step = 0 best_accuracy = -float("inf") # create model if args.resume: logging("Resuming from {}...".format(args.model_dir)) net_config = modeling.ModelConfig.init_from_json( args.config_path, args) model = modeling.FunnelTFM(net_config, args) model_param, optimizer = torch.load(args.model_path, map_location="cpu") logging(model.load_state_dict(model_param, strict=False)) model = model.to(args.device) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to(args.device) best_accuracy, train_step = torch.load(args.var_path) else: # create new model if args.init_ckpt: logging("Init from ckpt {}".format(args.init_ckpt)) net_config = modeling.ModelConfig.init_from_json( args.init_ckpt_config, args) model = modeling.FunnelTFM(net_config, args) print( model.load_state_dict(torch.load(args.init_ckpt), strict=False)) else: logging("init model") net_config = modeling.ModelConfig.init_from_args(args) model = modeling.FunnelTFM(net_config, args) net_config.to_json(args.config_path) model = model.to(args.device) # create new optimizer if args.fp16: from apex.optimizers import FusedAdam import apex.amp as amp optimizer = FusedAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) amp_model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt) else: try: from apex.optimizers import FusedAdam optimizer = FusedAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), eps=1e-6, weight_decay=args.weight_decay) except ImportError as e: logging("use pytorch optimizer") optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.99), eps=1e-6, weight_decay=args.weight_decay) amp_model = model if args.distributed: if args.ddp_backend == "apex": from apex.parallel import DistributedDataParallel as DDP 
para_model = DDP(amp_model) else: from torch.nn.parallel import DistributedDataParallel as DDP para_model = DDP(amp_model, device_ids=[args.device_id], find_unused_parameters=True) else: para_model = amp_model ############################################################################ # Log args ############################################################################ logging("=" * 100) for k, v in args.__dict__.items(): logging(" - {} : {}".format(k, v)) logging("=" * 100) ############################################################################ # Training ############################################################################ if not args.test_only: tr_loader = data.BucketIterator(tr_data, args.train_bsz, args.pad_id, args.seg_id_pad, args.device, args.max_length) if args.distributed: num_data = len(tr_data) // args.distributed_world_size else: num_data = len(tr_data) num_tr_batch = (num_data + args.train_bsz - 1) // args.train_bsz args.train_steps = num_tr_batch * args.epochs args.warmup_steps = int(args.train_steps * args.warmup_prop) num_example = torch.Tensor([0]).to(args.device) num_correct = torch.Tensor([0]).to(args.device) if args.dataset in ["CoLA"]: num_tp = torch.Tensor([0]).to(args.device) num_fp = torch.Tensor([0]).to(args.device) num_tn = torch.Tensor([0]).to(args.device) num_fn = torch.Tensor([0]).to(args.device) for epoch in range(args.epochs): #### One epoch for i, (sent, seg_id, label) in enumerate( tr_loader.get_iter(epoch, distributed=args.distributed)): optimizer.zero_grad() _, ret_dict = para_model(sent, seg_id=seg_id, cls_target=label) cls_loss = ret_dict["cls_loss"] cls_corr = ret_dict["cls_corr"] if args.fp16: with amp.scale_loss(cls_loss, optimizer) as scaled_loss: scaled_loss.backward() else: cls_loss.backward() num_correct += cls_corr.detach() num_example += len(sent) if args.dataset in ["CoLA"]: tp, fp, tn, fn = confusion_matrix(ret_dict["cls_pred"], label) num_tp = num_tp + tp num_fp = num_fp + fp num_tn = num_tn + tn num_fn = num_fn + fn if args.clip > 0: if args.fp16: gnorm = torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.clip) else: gnorm = torch.nn.utils.clip_grad_norm_( model.parameters(), args.clip) else: gnorm = 0 for p in model.parameters(): if p.grad is not None: param_gnorm = p.grad.data.norm(2) gnorm += param_gnorm.item()**2 gnorm = gnorm**(1. 
/ 2) train_step += 1 adjust_lr(args, train_step, optimizer) optimizer.step() ##### training stat if (i + 1) % (num_tr_batch // args.n_log_epoch) == 0: if args.distributed: torch.distributed.all_reduce( num_correct, op=torch.distributed.ReduceOp.SUM) torch.distributed.all_reduce( num_example, op=torch.distributed.ReduceOp.SUM) if args.dataset in ["CoLA"]: torch.distributed.all_reduce( num_tp, op=torch.distributed.ReduceOp.SUM) torch.distributed.all_reduce( num_fp, op=torch.distributed.ReduceOp.SUM) torch.distributed.all_reduce( num_tn, op=torch.distributed.ReduceOp.SUM) torch.distributed.all_reduce( num_fn, op=torch.distributed.ReduceOp.SUM) if is_master(args): if args.dataset in ["CoLA"]: corref = _compute_metric_based_on_keys( "corr", num_tp.item(), num_fp.item(), num_tn.item(), num_fn.item()) logging( "[{:>02d}/{:>08d}] Train | corref {:.4f} | gnorm {:.2f} " "| lr {:.6f}".format( epoch, train_step, corref, gnorm, optimizer.param_groups[0]["lr"])) else: accuracy = num_correct.item() / num_example.item() logging( "[{:>02d}/{:>08d}] Train | accu {:.4f} | gnorm {:.2f} " "| lr {:.6f}".format( epoch, train_step, accuracy, gnorm, optimizer.param_groups[0]["lr"])) num_example.zero_() num_correct.zero_() if args.dataset in ["CoLA"]: num_tp.zero_() num_fp.zero_() num_tn.zero_() num_fn.zero_() ##### validation if train_step % (args.train_steps // 10) == 0: accuracy = evaluate(args, model, va_loader) if is_master(args): if accuracy > best_accuracy: torch.save([model.state_dict(), optimizer], args.model_path) torch.save([best_accuracy, train_step], args.var_path) best_accuracy = max(accuracy, best_accuracy) logging( "[{}] Valid | curr accu {:.4f} | best accu {:.4f}". format(train_step // (args.train_steps // 10), accuracy, best_accuracy)) ##### make prediction if is_master(args) and args.write_prediction: rev_label_dict = dict((v, k) for k, v in label_dict.items()) model.load_state_dict(torch.load(args.model_path, map_location="cpu")[0], strict=False) model = model.to(args.device) predict(args, model, te_loader, os.path.join(args.model_dir, "test_results.txt"), rev_label_dict) predict(args, model, va_loader, os.path.join(args.model_dir, "valid_results.txt"), rev_label_dict)
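# NOTE: `adjust_lr(args, step, optimizer)` is called above but not shown in
# this file. The sketch below assumes the common linear-warmup-then-linear-
# decay schedule implied by `args.warmup_steps` and `args.train_steps`; the
# actual schedule in the original code may differ.
def adjust_lr(args, step, optimizer):
    """Linear warmup followed by linear decay of the learning rate (sketch)."""
    if step <= args.warmup_steps:
        lr = args.lr * step / max(1, args.warmup_steps)
    else:
        remaining = max(0, args.train_steps - step)
        lr = args.lr * remaining / max(1, args.train_steps - args.warmup_steps)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr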
def save_checkpoint(args, controller, epoch_itr, val_loss):
    import distributed_utils, meters

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args.no_save or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = controller.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0)
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0)
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None
        and (not hasattr(save_checkpoint, 'best')
             or is_better(val_loss, save_checkpoint.best)))
    checkpoint_conds['checkpoint_last.pt'] = not args.no_last_checkpoints

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        controller.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            shutil.copyfile(checkpoints[0], cp)

        write_timer.stop()
        print('| saved checkpoint {} (epoch {} @ {} updates) '
              '(writing took {} seconds)'.format(
                  checkpoints[0], epoch, updates, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
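# NOTE: `checkpoint_paths` is used above but not defined here. A minimal
# sketch, assuming it returns the checkpoint files matching `pattern` sorted
# in descending order of the captured number, as the comments above require;
# the original helper may differ in detail.
import re


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Return checkpoints in `path` matching `pattern`, newest first (sketch)."""
    pt_regexp = re.compile(pattern)
    entries = []
    for fname in os.listdir(path):
        m = pt_regexp.fullmatch(fname)
        if m is not None:
            # sort by the captured integer (epoch or update count), descending
            entries.append((int(m.group(1)), os.path.join(path, fname)))
    return [p for _, p in sorted(entries, reverse=True)]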