class Triple2vec_train(TrainEngine):
    """Training engine wrapper for the Triple2vec model.

    Loads the grocery dataset, builds a ``Triple2vecEngine``, and runs the
    shared ``_train`` loop from the ``TrainEngine`` base class.
    """

    def __init__(self, config):
        """Constructor.

        Args:
            config (dict): All the parameters for the model.
        """
        self.config = config
        super(Triple2vec_train, self).__init__(self.config)
        # Resolve the compute device once; get_device() comes from the base class.
        self.gpu_id, self.config["device_str"] = self.get_device()

    def load_dataset(self):
        """Load the split dataset and record user/item counts in the config."""
        self.data = GroceryData(
            split_dataset=load_split_dataset(self.config), config=self.config
        )
        self.config["model"]["n_users"] = self.data.n_users
        self.config["model"]["n_items"] = self.data.n_items

    def train(self):
        """Train the model and return the best validation performance."""
        self.load_dataset()
        self.engine = Triple2vecEngine(self.config)
        self.engine.data = self.data
        # Sample (user, item, item) triples once up front for the whole run.
        self.train_data = self.data.sample_triple()
        model_cfg = self.config["model"]
        triples = torch.LongTensor(self.train_data.to_numpy()).to(self.engine.device)
        loader = DataLoader(
            triples,
            batch_size=model_cfg["batch_size"],
            shuffle=True,
            drop_last=True,
        )
        self.monitor = Monitor(
            log_dir=self.config["system"]["run_dir"], delay=1, gpu_id=self.gpu_id
        )
        self.model_save_dir = os.path.join(
            self.config["system"]["model_save_dir"], model_cfg["save_name"]
        )
        self._train(self.engine, loader, self.model_save_dir)
        self.config["run_time"] = self.monitor.stop()
        return self.eval_engine.best_valid_performance
def load_dataset(self):
    """Load the split dataset and publish its dimensions to the config.

    Side effects: sets ``self.data`` and writes ``n_users`` / ``n_items``
    into ``self.config["model"]``.
    """
    splits = load_split_dataset(self.config)
    self.data = GroceryData(split_dataset=splits, config=self.config)
    model_cfg = self.config["model"]
    model_cfg["n_users"] = self.data.n_users
    model_cfg["n_items"] = self.data.n_items
class VBCAR_train(TrainEngine):
    """Training engine for the VBCAR model.

    Unlike the other engines in this file, VBCAR anneals its ``alpha``
    weight and decays the learning rate manually inside its own epoch loop
    instead of delegating to the base class ``_train``.
    """

    def __init__(self, config):
        """Constructor.

        Args:
            config (dict): All the parameters for the model.
        """
        self.config = config
        super(VBCAR_train, self).__init__(self.config)
        # Bug fix: train() reads self.gpu_id when constructing the Monitor,
        # but it was never initialised here. The sibling Triple2vec_train
        # sets it the same way right after super().__init__, so the base
        # class presumably does not set it — TODO confirm against TrainEngine.
        self.gpu_id, self.config["device_str"] = self.get_device()

    def load_dataset(self):
        """Load the split dataset and record user/item counts in the config."""
        split_data = load_split_dataset(self.config)
        self.data = GroceryData(split_dataset=split_data, config=self.config)
        self.config["model"]["n_users"] = self.data.n_users
        self.config["model"]["n_items"] = self.data.n_items

    def train(self):
        """Run the full VBCAR training loop.

        Returns:
            The best validation performance recorded by ``self.eval_engine``.
        """
        self.load_dataset()
        self.train_data = self.data.sample_triple()
        # Linear step for annealing alpha from its initial value towards 1
        # over max_epoch epochs (consumed by the engine/model, not used here).
        self.config["model"]["alpha_step"] = (1 - self.config["model"]["alpha"]) / (
            self.config["model"]["max_epoch"]
        )
        self.config["user_fea"] = self.data.user_feature
        self.config["item_fea"] = self.data.item_feature
        self.engine = VBCAREngine(self.config)
        self.engine.data = self.data
        self.monitor = Monitor(
            log_dir=self.config["system"]["run_dir"], delay=1, gpu_id=self.gpu_id
        )
        print("Start training... \n")
        epoch_bar = tqdm(range(self.config["model"]["max_epoch"]), file=sys.stdout)
        for epoch in epoch_bar:
            print(f"Epoch {epoch} starts !")
            print("-" * 80)
            if epoch > 0 and self.eval_engine.n_no_update == 0:
                # The previous epoch improved the validation metric, so
                # checkpoint the current (best-so-far) model.
                self.engine.save_checkpoint(
                    model_dir=os.path.join(
                        self.config["system"]["model_save_dir"], "model.cpk"
                    )
                )
            if self.eval_engine.n_no_update >= MAX_N_UPDATE:
                print(
                    "Early stop criterion triggered, no performance update for {:} times".format(
                        MAX_N_UPDATE
                    )
                )
                break
            # Rebuild the loader each epoch so shuffling reshuffles the triples.
            data_loader = DataLoader(
                torch.LongTensor(self.train_data.to_numpy()).to(self.engine.device),
                batch_size=self.config["model"]["batch_size"],
                shuffle=True,
                drop_last=True,
            )
            self.engine.train_an_epoch(data_loader, epoch_id=epoch)
            self.eval_engine.train_eval(
                self.data.valid[0], self.data.test[0], self.engine.model, epoch
            )
            # Anneal alpha exponentially towards 1; the exp term only becomes
            # significant in roughly the last 20 epochs, and alpha is capped at 1.
            self.engine.model.alpha = min(
                self.config["model"]["alpha"]
                + math.exp(epoch - self.config["model"]["max_epoch"] + 20),
                1,
            )
            # Decay the learning rate by half every 10 epochs.
            lr = self.config["model"]["lr"] * (0.5 ** (epoch // 10))
            for param_group in self.engine.optimizer.param_groups:
                param_group["lr"] = lr
        self.config["run_time"] = self.monitor.stop()
        return self.eval_engine.best_valid_performance