def _save_averaged_checkpoint(args, extra_state):
    """Average the N most recent checkpoints and save the result to disk.

    Args:
        args: parsed CLI namespace; reads max_checkpoints_kept,
            no_epoch_checkpoints, generate_bleu_eval_avg_checkpoints,
            log_verbose, and save_dir.
        extra_state: dict with "epoch", "batch_offset", and a
            "last_checkpoints" object exposing get_last_n().

    Returns:
        Path of the newly written averaged-checkpoint file.

    Raises:
        argparse.ArgumentTypeError: if args.max_checkpoints_kept == 0.
    """
    epoch = extra_state["epoch"]
    offset = extra_state["batch_offset"]

    # Lazily build the managed collection on the first call; it persists
    # across calls as an attribute of the function itself.
    if not hasattr(_save_averaged_checkpoint, "last_avg_checkpoints"):
        if args.max_checkpoints_kept == 0:
            raise argparse.ArgumentTypeError("--max-checkpoints-kept must be != 0.")
        # A negative value means "keep all": capacity clamps to 1 but
        # auto_clear is disabled so nothing gets deleted.
        auto_clear = args.max_checkpoints_kept > 0
        _save_averaged_checkpoint.last_avg_checkpoints = ManagedCheckpoints(
            max(args.max_checkpoints_kept, 1), auto_clear=auto_clear
        )

    num_to_average = (
        1 if args.no_epoch_checkpoints else args.generate_bleu_eval_avg_checkpoints
    )
    last_checkpoints = extra_state["last_checkpoints"].get_last_n(num_to_average)

    if args.log_verbose:
        print(
            f"Reading {len(last_checkpoints)} previous "
            f"checkpoints for averaging in epoch {epoch}, offset {offset}.",
            flush=True,
        )

    averaged_state = average_checkpoints.average_checkpoints(last_checkpoints)
    filename = os.path.join(args.save_dir, f"averaged_checkpoint{epoch}_{offset}.pt")
    # Register the new file before writing so the manager can evict old ones.
    _save_averaged_checkpoint.last_avg_checkpoints.append(filename)

    if args.log_verbose:
        print(
            f"Preparing to save averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    torch.save(averaged_state, filename)
    if args.log_verbose:
        print(
            f"Finished saving averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    return filename
def test_average_checkpoints(self):
    """Average two saved checkpoints and verify the per-key element means.

    Floating-point tensors must average exactly; the integer tensor is
    expected to truncate under integer division.
    """
    params_0 = collections.OrderedDict(
        [
            ("a", torch.DoubleTensor([100.0])),
            ("b", torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
            ("c", torch.IntTensor([7, 8, 9])),
        ]
    )
    params_1 = collections.OrderedDict(
        [
            ("a", torch.DoubleTensor([1.0])),
            ("b", torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
            ("c", torch.IntTensor([2, 2, 2])),
        ]
    )
    params_avg = collections.OrderedDict(
        [
            ("a", torch.DoubleTensor([50.5])),
            ("b", torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
            # We expect truncation for integer division
            ("c", torch.IntTensor([4, 5, 5])),
        ]
    )
    fd_0, path_0 = tempfile.mkstemp()
    fd_1, path_1 = tempfile.mkstemp()
    # Clean up the temp files even if averaging raises, so a failing
    # assertion doesn't leave stray files behind.
    try:
        torch.save(collections.OrderedDict([("model", params_0)]), path_0)
        torch.save(collections.OrderedDict([("model", params_1)]), path_1)
        output = average_checkpoints([path_0, path_1])["model"]
    finally:
        os.close(fd_0)
        os.remove(path_0)
        os.close(fd_1)
        os.remove(path_1)
    for (k_expected, v_expected), (k_out, v_out) in zip(
        params_avg.items(), output.items()
    ):
        # assertEqual (not the deprecated assertEquals alias); the key-list
        # detail lines need the f prefix or the {...} placeholders would be
        # emitted literally.
        self.assertEqual(
            k_expected,
            k_out,
            f"Key mismatch - expected {k_expected} but found {k_out}. "
            f"(Expected list of keys: {params_avg.keys()} "
            f"vs actual list of keys: {output.keys()})",
        )
        np.testing.assert_allclose(
            v_expected.numpy(),
            v_out.numpy(),
            err_msg=f"Tensor value mismatch for key {k_expected}",
        )
def _save_averaged_checkpoint(args, epoch, offset):
    """Average the tracked recent checkpoints and write the result to disk.

    Args:
        args: parsed CLI namespace; reads log_verbose and save_dir.
        epoch: epoch number used in the output filename and log lines.
        offset: batch offset used in the output filename and log lines.

    Returns:
        Path of the averaged-checkpoint file that was written.

    NOTE(review): relies on ``save_checkpoint.last_checkpoints`` having been
    populated elsewhere — confirm call ordering with the caller.
    """
    verbose = args.log_verbose
    if verbose:
        print(
            f'Reading {len(save_checkpoint.last_checkpoints)} previous '
            f'checkpoints for averaging in epoch {epoch}, offset {offset}.',
            flush=True,
        )

    averaged_state = average_checkpoints.average_checkpoints(
        save_checkpoint.last_checkpoints
    )
    filename = os.path.join(
        args.save_dir, f'averaged_checkpoint{epoch}_{offset}.pt'
    )

    if verbose:
        print(
            f'Preparing to save averaged checkpoint for '
            f'epoch {epoch}, offset {offset}.',
            flush=True,
        )
    torch.save(averaged_state, filename)
    if verbose:
        print(
            f'Finished saving averaged checkpoint for '
            f'epoch {epoch}, offset {offset}.',
            flush=True,
        )
    return filename