test_loader = data.DataLoader(test_dataset,
    batch_size=args.test_num_ng + 1, shuffle=False, num_workers=0)

# CREATE MODEL
if config.model == 'NeuMF-pre':
    assert os.path.exists(config.GMF_model_path), 'lack of GMF model'
    assert os.path.exists(config.MLP_model_path), 'lack of MLP model'
    GMF_model = torch.load(config.GMF_model_path)
    MLP_model = torch.load(config.MLP_model_path)
else:
    GMF_model = None
    MLP_model = None

model = model.NCF(user_num, item_num, args.factor_num, args.num_layers,
    args.dropout, config.model, GMF_model, MLP_model)
model.cuda()
loss_function = nn.BCEWithLogitsLoss()

if config.model == 'NeuMF-pre':
    optimizer = optim.SGD(model.parameters(), lr=args.lr)
else:
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

# TRAINING
count, best_hr = 0, 0
for epoch in range(args.epochs):
    model.train()  # Enable dropout (if any).
    start_time = time.time()
    train_loader.dataset.ng_sample()
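    # The baseline epoch body then iterates over train_loader and takes a
    # standard BCE-with-logits step. A minimal sketch of that inner loop,
    # assuming the loader yields (user, item, label) batches as in the NCF
    # data pipeline above:
    for user, item, label in train_loader:
        user = user.cuda()
        item = item.cuda()
        label = label.float().cuda()

        model.zero_grad()                         # clear gradients
        prediction = model(user, item)            # raw logits from NCF
        loss = loss_function(prediction, label)   # BCEWithLogitsLoss on logits
        loss.backward()
        optimizer.step()
        count += 1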
if args.autoscale_bsz:
    train_loader.autoscale_batch_size(8192, local_bsz_bounds=(32, 512),
        gradient_accumulation=args.gradient_accumulation)

########################### CREATE MODEL #################################
if model_type == 'NeuMF-pre':
    assert os.path.exists(GMF_model_path), 'lack of GMF model'
    assert os.path.exists(MLP_model_path), 'lack of MLP model'
    GMF_model = torch.load(GMF_model_path)
    MLP_model = torch.load(MLP_model_path)
else:
    GMF_model = None
    MLP_model = None

network = model.NCF(user_num, item_num, args.factor_num, args.num_layers,
    args.dropout, model_type, GMF_model, MLP_model)

adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() else "gloo")
network.cuda()
loss_function = torch.nn.BCEWithLogitsLoss()

if model_type == 'NeuMF-pre':
    optimizer = optim.SGD(network.parameters(), lr=args.lr)
else:
    optimizer = optim.Adam(network.parameters(), lr=args.lr)

network = adl.AdaptiveDataParallel(network, optimizer, find_unused_parameters=True)

########################### TRAINING #####################################
count, best_hr = 0, 0
tensorboard_dir = os.path.join(os.getenv("ADAPTDL_TENSORBOARD_LOGDIR", "/tmp"),
                               adaptdl.env.job_id())
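# From here, the AdaptDL version would typically drive the epoch loop through
# adaptdl.torch so that a rescaled or restarted job resumes at the correct
# epoch rather than restarting from zero. A minimal sketch, continuing the
# names above; the SummaryWriter wiring and logged scalar name are
# assumptions for illustration, not the exact example code:
from torch.utils.tensorboard import SummaryWriter

with SummaryWriter(tensorboard_dir) as writer:
    for epoch in adl.remaining_epochs_until(args.epochs):
        network.train()                   # enable dropout
        train_loader.dataset.ng_sample()  # resample negatives each epoch
        for user, item, label in train_loader:
            user, item = user.cuda(), item.cuda()
            label = label.float().cuda()

            optimizer.zero_grad()
            loss = loss_function(network(user, item), label)
            loss.backward()
            optimizer.step()
            count += 1
        # Log the last batch loss for this epoch (hypothetical tag name).
        writer.add_scalar("Loss/train", loss.item(), epoch)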