Example #1
def get_game(params: argparse.Namespace, checkpoint_path: str):
    # Rebuild the game architecture from the parsed command-line options.
    game = build_game(params)
    # The saved checkpoint exposes model_state_dict as an attribute
    # (a NamedTuple-style checkpoint rather than a plain dict).
    checkpoint = torch.load(checkpoint_path)
    game.load_state_dict(checkpoint.model_state_dict)
    return game
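A minimal usage sketch for get_game; the option helper mirrors the later examples, and the checkpoint path is an illustrative assumption:

import sys

# Hypothetical call: parse this run's CLI options, then restore the trained game.
opts = get_common_opts(params=sys.argv[1:])
game = get_game(opts, checkpoint_path="/path/to/checkpoint.tar")  # illustrative path
game.eval()  # switch to inference mode before probing the trained agents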
Example #2
def main(params):
    opts = get_common_opts(params=params)
    print(f"{opts}\n")
    assert not opts.batch_size % 2, (
        f"Batch size must be a multiple of 2. Found {opts.batch_size} instead")
    print(
        f"Running a distruted training is set to: {opts.distributed_context.is_distributed}. "
        f"World size is {opts.distributed_context.world_size}. "
        f"Using batch of size {opts.batch_size} on {opts.distributed_context.world_size} device(s)\n"
        f"Applying augmentations: {opts.use_augmentations} with image size: {opts.image_size}.\n"
    )
    if not opts.distributed_context.is_distributed and opts.pdb:
        breakpoint()

    train_loader = get_dataloader(
        dataset_dir=opts.dataset_dir,
        image_size=opts.image_size,
        batch_size=opts.batch_size,
        num_workers=opts.num_workers,
        use_augmentations=opts.use_augmentations,
        is_distributed=opts.distributed_context.is_distributed,
        return_original_image=opts.return_original_image,
        seed=opts.random_seed)

    game = build_game(opts)

    # Build parameter groups, excluding batch-norm ('bn') weights from weight decay.
    model_parameters = add_weight_decay(game,
                                        opts.weight_decay,
                                        skip_name='bn')

    optimizer = torch.optim.SGD(
        model_parameters,
        lr=opts.lr,
        momentum=0.9,
    )
    optimizer_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=opts.n_epochs)

    if opts.distributed_context.is_distributed and opts.distributed_context.world_size > 2 and opts.use_larc:
        optimizer = LARC(optimizer,
                         trust_coefficient=0.001,
                         clip=False,
                         eps=1e-8)

    callbacks = get_callbacks(
        shared_vision=opts.shared_vision,
        n_epochs=opts.n_epochs,
        checkpoint_dir=opts.checkpoint_dir,
        sender=game.game.sender,
        train_gs_temperature=opts.train_gs_temperature,
        minimum_gs_temperature=opts.minimum_gs_temperature,
        update_gs_temp_frequency=opts.update_gs_temp_frequency,
        gs_temperature_decay=opts.gs_temperature_decay,
        is_distributed=opts.distributed_context.is_distributed)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        optimizer_scheduler=optimizer_scheduler,
        train_data=train_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    print("| FINISHED JOB")
Example #3
def main(params):
    opts = get_common_opts(params=params)
    print(f"{opts}\n")
    assert (
        not opts.batch_size % 2
    ), f"Batch size must be a multiple of 2. Found {opts.batch_size} instead"
    print(
        f"Running a distruted training is set to: {opts.distributed_context.is_distributed}. "
        f"World size is {opts.distributed_context.world_size}. "
        f"Using batch of size {opts.batch_size} on {opts.distributed_context.world_size} device(s)\n"
        f"Applying augmentations: {opts.use_augmentations} with image size: {opts.image_size}.\n"
    )
    if not opts.distributed_context.is_distributed and opts.pdb:
        breakpoint()
    if opts.use_distributed_negatives and not opts.distributed_context.is_distributed:
        sys.exit("Distributed negatives cannot be used in non-distributed context")

    train_loader = get_dataloader(
        dataset_dir=opts.dataset_dir,
        dataset_name=opts.dataset_name,
        image_size=opts.image_size,
        batch_size=opts.batch_size,
        num_workers=opts.num_workers,
        is_distributed=opts.distributed_context.is_distributed,
        seed=opts.random_seed,
        use_augmentations=opts.use_augmentations,
        return_original_image=opts.return_original_image,
    )

    game = build_game(opts)

    model_parameters = add_weight_decay(game, opts.weight_decay, skip_name="bn")

    optimizer = torch.optim.SGD(
        model_parameters,
        lr=opts.lr,
        momentum=0.9,
    )
    optimizer_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=opts.n_epochs
    )

    if (
        opts.distributed_context.is_distributed
        and opts.distributed_context.world_size > 2
        and opts.use_larc
    ):
        optimizer = LARC(optimizer, trust_coefficient=0.001, clip=False, eps=1e-8)

    callbacks = get_callbacks(
        shared_vision=opts.shared_vision,
        n_epochs=opts.n_epochs,
        checkpoint_dir=opts.checkpoint_dir,
        sender=game.game.sender,
        train_gs_temperature=opts.train_gs_temperature,
        minimum_gs_temperature=opts.minimum_gs_temperature,
        update_gs_temp_frequency=opts.update_gs_temp_frequency,
        gs_temperature_decay=opts.gs_temperature_decay,
        is_distributed=opts.distributed_context.is_distributed,
    )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        optimizer_scheduler=optimizer_scheduler,
        train_data=train_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    # Shared loader arguments for the two held-out test sets below.
    data_args = {
        "image_size": opts.image_size,
        "batch_size": opts.batch_size,
        "dataset_name": "imagenet",
        "num_workers": opts.num_workers,
        "use_augmentations": False,
        "is_distributed": opts.distributed_context.is_distributed,
        "seed": opts.random_seed,
    }
    i_test_loader = get_dataloader(
        dataset_dir="/datasets01/imagenet_full_size/061417/val", **data_args
    )
    o_test_loader = get_dataloader(
        dataset_dir="/private/home/mbaroni/agentini/representation_learning/generalizaton_set_construction/80_generalization_data_set/",
        **data_args,
    )

    _, i_test_interaction = trainer.eval(i_test_loader)
    dump = {k: v.mean().item() for k, v in i_test_interaction.aux.items()}
    dump.update(mode="VALIDATION_I_TEST")
    print(json.dumps(dump), flush=True)

    _, o_test_interaction = trainer.eval(o_test_loader)
    dump = {k: v.mean().item() for k, v in o_test_interaction.aux.items()}
    dump.update(mode="VALIDATION_O_TEST")
    print(json.dumps(dump), flush=True)

    if opts.checkpoint_dir:
        output_path = Path(opts.checkpoint_dir)
        output_path.mkdir(exist_ok=True, parents=True)
        torch.save(i_test_interaction, output_path / "i_test_interaction")
        torch.save(o_test_interaction, output_path / "o_test_interaction")

    print("| FINISHED JOB")