import argparse
import json
import sys
from pathlib import Path

import torch

import egg.core as core

# The project-level helpers below (build_game, get_dataloader, get_callbacks,
# add_weight_decay, get_common_opts, LARC) are assumed to live in the
# surrounding EGG zoo package; adjust the module paths to your checkout.
from egg.zoo.emcom_as_ssl.data import get_dataloader
from egg.zoo.emcom_as_ssl.game_callbacks import get_callbacks
from egg.zoo.emcom_as_ssl.games import build_game
from egg.zoo.emcom_as_ssl.LARC import LARC
from egg.zoo.emcom_as_ssl.utils import add_weight_decay, get_common_opts


def get_game(params: argparse.Namespace, checkpoint_path: str):
    # Rebuild the game architecture from the run options, then restore the
    # trained weights. EGG checkpoints expose a model_state_dict field.
    game = build_game(params)
    checkpoint = torch.load(checkpoint_path)
    game.load_state_dict(checkpoint.model_state_dict)
    return game
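# A minimal usage sketch for get_game (the wrapper name, the empty-argv call
# to get_common_opts, and the eval() switch are illustrative assumptions,
# not part of the original script):
def load_game_for_eval(checkpoint_path: str, params=None):
    opts = get_common_opts(params=params if params is not None else [])
    game = get_game(opts, checkpoint_path)
    game.eval()  # inference mode: disables dropout / batch-norm updates
    return game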
def main(params):
    opts = get_common_opts(params=params)
    print(f"{opts}\n")
    assert (
        not opts.batch_size % 2
    ), f"Batch size must be a multiple of 2. Found {opts.batch_size} instead"
    print(
        f"Distributed training is set to: {opts.distributed_context.is_distributed}. "
        f"World size is {opts.distributed_context.world_size}. "
        f"Using batch of size {opts.batch_size} on {opts.distributed_context.world_size} device(s)\n"
        f"Applying augmentations: {opts.use_augmentations} with image size: {opts.image_size}.\n"
    )
    if not opts.distributed_context.is_distributed and opts.pdb:
        breakpoint()

    if opts.use_distributed_negatives and not opts.distributed_context.is_distributed:
        sys.exit("Distributed negatives cannot be used in a non-distributed context")

    train_loader = get_dataloader(
        dataset_dir=opts.dataset_dir,
        dataset_name=opts.dataset_name,
        image_size=opts.image_size,
        batch_size=opts.batch_size,
        num_workers=opts.num_workers,
        is_distributed=opts.distributed_context.is_distributed,
        seed=opts.random_seed,
        use_augmentations=opts.use_augmentations,
        return_original_image=opts.return_original_image,
    )

    game = build_game(opts)

    # Apply weight decay to all parameters except batch-norm ones.
    model_parameters = add_weight_decay(game, opts.weight_decay, skip_name="bn")

    optimizer = torch.optim.SGD(
        model_parameters,
        lr=opts.lr,
        momentum=0.9,
    )
    optimizer_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=opts.n_epochs
    )

    # Wrap SGD in LARC for large-batch stability when training distributed
    # on more than 2 workers.
    if (
        opts.distributed_context.is_distributed
        and opts.distributed_context.world_size > 2
        and opts.use_larc
    ):
        optimizer = LARC(optimizer, trust_coefficient=0.001, clip=False, eps=1e-8)

    callbacks = get_callbacks(
        shared_vision=opts.shared_vision,
        n_epochs=opts.n_epochs,
        checkpoint_dir=opts.checkpoint_dir,
        sender=game.game.sender,
        train_gs_temperature=opts.train_gs_temperature,
        minimum_gs_temperature=opts.minimum_gs_temperature,
        update_gs_temp_frequency=opts.update_gs_temp_frequency,
        gs_temperature_decay=opts.gs_temperature_decay,
        is_distributed=opts.distributed_context.is_distributed,
    )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        optimizer_scheduler=optimizer_scheduler,
        train_data=train_loader,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    # Post-training evaluation: the in-distribution (ImageNet val) and
    # out-of-distribution test loaders share the same arguments, with
    # augmentations disabled.
    data_args = {
        "image_size": opts.image_size,
        "batch_size": opts.batch_size,
        "dataset_name": "imagenet",
        "num_workers": opts.num_workers,
        "use_augmentations": False,
        "is_distributed": opts.distributed_context.is_distributed,
        "seed": opts.random_seed,
    }
    i_test_loader = get_dataloader(
        dataset_dir="/datasets01/imagenet_full_size/061417/val", **data_args
    )
    o_test_loader = get_dataloader(
        dataset_dir="/private/home/mbaroni/agentini/representation_learning/generalizaton_set_construction/80_generalization_data_set/",
        **data_args,
    )

    _, i_test_interaction = trainer.eval(i_test_loader)
    dump = {k: v.mean().item() for k, v in i_test_interaction.aux.items()}
    dump.update(dict(mode="VALIDATION_I_TEST"))
    print(json.dumps(dump), flush=True)

    _, o_test_interaction = trainer.eval(o_test_loader)
    dump = {k: v.mean().item() for k, v in o_test_interaction.aux.items()}
    dump.update(dict(mode="VALIDATION_O_TEST"))
    print(json.dumps(dump), flush=True)

    if opts.checkpoint_dir:
        output_path = Path(opts.checkpoint_dir)
        output_path.mkdir(exist_ok=True, parents=True)
        torch.save(i_test_interaction, output_path / "i_test_interaction")
        torch.save(o_test_interaction, output_path / "o_test_interaction")

    print("| FINISHED JOB")
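# Assumed command-line entry point: get_common_opts is expected to parse an
# argv-style list, mirroring the EGG zoo training scripts.
if __name__ == "__main__":
    main(sys.argv[1:])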