def _train_fn_amp(local_rank, world_size):
    """Per-process entry point for distributed AMP (fp16) training.

    Args:
        local_rank: rank of this worker on the local node; also used as the
            global rank here (single-node assumption — TODO confirm).
        world_size: total number of participating processes.
    """
    # The default env:// init method reads RANK / WORLD_SIZE / LOCAL_RANK
    # from the environment, so export them before creating the group.
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(local_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    dist.init_process_group(backend="nccl", world_size=world_size)
    # Run the experiment on a mixed-precision engine, then tear down the group.
    train_experiment(dl.Engine(fp16=True))
    dist.destroy_process_group()
if __name__ == "__main__":
    # Fix RNG seeds up front for reproducible runs.
    set_global_seed(42)

    # MovieLens splits (downloaded into the working directory).
    train_dataset = MovieLens(root=".", train=True, download=True)
    test_dataset = MovieLens(root=".", train=False, download=True)
    loaders = {
        "train": DataLoader(train_dataset, batch_size=32, collate_fn=collate_fn_train),
        "valid": DataLoader(test_dataset, batch_size=32, collate_fn=collate_fn_valid),
    }

    # Item-catalog size taken from the length of one sample
    # (presumably a per-user item-interaction vector — TODO confirm).
    item_num = len(train_dataset[0])
    model = MultiVAE([200, 600, item_num], dropout=0.5)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Decay LR by 10x every 20 epochs.
    lr_scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
    engine = dl.Engine()

    # KL-annealing hyperparameters; NOTE(review): names match the Mult-VAE
    # annealing scheme — verify they are consumed by RecSysRunner.
    hparams = {
        "anneal_cap": 0.2,
        "total_anneal_steps": 6000,
    }

    # Ranking metrics at cutoffs 20/50/100, plus the backward / optimizer /
    # scheduler step callbacks wired to the "loss" key.
    callbacks = [
        dl.NDCGCallback("logits", "targets", [20, 50, 100]),
        dl.MAPCallback("logits", "targets", [20, 50, 100]),
        dl.MRRCallback("logits", "targets", [20, 50, 100]),
        dl.HitrateCallback("logits", "targets", [20, 50, 100]),
        dl.BackwardCallback("loss"),
        dl.OptimizerCallback("loss", accumulation_steps=1),
        dl.SchedulerCallback(),
    ]
    # Runner is constructed here; the training call that consumes the
    # objects above appears to continue past this chunk.
    runner = RecSysRunner()