Beispiel #1
0
def _save_averaged_checkpoint(args, extra_state):
    """Average the most recent checkpoints and write the result to disk.

    State (the rolling list of averaged-checkpoint filenames) is kept as a
    function attribute so it survives across calls without a module global.
    Returns the path of the averaged checkpoint file that was saved.
    """
    epoch = extra_state["epoch"]
    offset = extra_state["batch_offset"]

    # Lazy one-time initialization of the managed-checkpoints tracker.
    if not hasattr(_save_averaged_checkpoint, "last_avg_checkpoints"):
        if args.max_checkpoints_kept == 0:
            raise argparse.ArgumentTypeError("--max-checkpoints-kept must be != 0.")
        # A negative max_checkpoints_kept disables auto-clearing (keep all).
        capacity = max(args.max_checkpoints_kept, 1)
        _save_averaged_checkpoint.last_avg_checkpoints = ManagedCheckpoints(
            capacity, auto_clear=args.max_checkpoints_kept > 0
        )

    num_to_average = (
        1 if args.no_epoch_checkpoints else args.generate_bleu_eval_avg_checkpoints
    )
    last_checkpoints = extra_state["last_checkpoints"].get_last_n(num_to_average)

    verbose = args.log_verbose
    if verbose:
        print(
            f"Reading {len(last_checkpoints)} previous "
            f"checkpoints for averaging in epoch {epoch}, offset {offset}.",
            flush=True,
        )
    averaged_state = average_checkpoints.average_checkpoints(last_checkpoints)
    filename = os.path.join(args.save_dir, f"averaged_checkpoint{epoch}_{offset}.pt")
    # Track the new file so older averaged checkpoints can be cleared.
    _save_averaged_checkpoint.last_avg_checkpoints.append(filename)
    if verbose:
        print(
            f"Preparing to save averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    torch.save(averaged_state, filename)
    if verbose:
        print(
            f"Finished saving averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    return filename
Beispiel #2
0
    def test_average_checkpoints(self):
        """average_checkpoints over two saved models averages each parameter.

        Float/double tensors are averaged exactly; integer tensors are
        expected to truncate on division.
        """
        params_0 = collections.OrderedDict(
            [
                ("a", torch.DoubleTensor([100.0])),
                ("b", torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
                ("c", torch.IntTensor([7, 8, 9])),
            ]
        )
        params_1 = collections.OrderedDict(
            [
                ("a", torch.DoubleTensor([1.0])),
                ("b", torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
                ("c", torch.IntTensor([2, 2, 2])),
            ]
        )
        params_avg = collections.OrderedDict(
            [
                ("a", torch.DoubleTensor([50.5])),
                ("b", torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
                # We expect truncation for integer division
                ("c", torch.IntTensor([4, 5, 5])),
            ]
        )

        fd_0, path_0 = tempfile.mkstemp()
        fd_1, path_1 = tempfile.mkstemp()
        # try/finally so the temp files are removed even if the code under
        # test (or an assertion below it) raises.
        try:
            torch.save(collections.OrderedDict([("model", params_0)]), path_0)
            torch.save(collections.OrderedDict([("model", params_1)]), path_1)

            output = average_checkpoints([path_0, path_1])["model"]
        finally:
            os.close(fd_0)
            os.remove(path_0)
            os.close(fd_1)
            os.remove(path_1)

        # Compare key sequences up front so a length mismatch cannot be
        # silently truncated by zip() below.
        self.assertEqual(
            list(params_avg.keys()),
            list(output.keys()),
            "Averaged model has a different set/order of keys than expected.",
        )
        for (k_expected, v_expected), (k_out, v_out) in zip(
            params_avg.items(), output.items()
        ):
            # assertEqual (not the deprecated assertEquals, removed in 3.12);
            # f-prefix added to the continuation strings so the key lists
            # actually interpolate instead of printing "{params_avg.keys()}".
            self.assertEqual(
                k_expected,
                k_out,
                f"Key mismatch - expected {k_expected} but found {k_out}. "
                f"(Expected list of keys: {params_avg.keys()} "
                f"vs actual list of keys: {output.keys()})",
            )
            np.testing.assert_allclose(
                v_expected.numpy(),
                v_out.numpy(),
                err_msg=f"Tensor value mismatch for key {k_expected}",
            )
Beispiel #3
0
def _save_averaged_checkpoint(args, epoch, offset):
    """Average the checkpoints tracked by save_checkpoint and persist them.

    Returns the path of the averaged checkpoint file that was written.
    """
    verbose = args.log_verbose
    if verbose:
        message = (
            f'Reading {len(save_checkpoint.last_checkpoints)} previous '
            f'checkpoints for averaging in epoch {epoch}, offset {offset}.'
        )
        print(message, flush=True)

    # Combine the tracked checkpoints into a single averaged state dict.
    averaged_state = average_checkpoints.average_checkpoints(
        save_checkpoint.last_checkpoints
    )
    filename = os.path.join(
        args.save_dir, f'averaged_checkpoint{epoch}_{offset}.pt'
    )

    if verbose:
        message = (
            f'Preparing to save averaged checkpoint for '
            f'epoch {epoch}, offset {offset}.'
        )
        print(message, flush=True)
    torch.save(averaged_state, filename)
    if verbose:
        message = (
            f'Finished saving averaged checkpoint for '
            f'epoch {epoch}, offset {offset}.'
        )
        print(message, flush=True)
    return filename