def training(args):
    """Train a sequence-to-sequence Transformer on the preprocessed corpus.

    Loads pickled token-index data, builds train/valid dataloaders,
    instantiates the model, then runs a mixed-precision train/validation
    loop, checkpointing whenever validation token accuracy improves.

    Args:
        args: argparse.Namespace carrying data paths, model hyper-parameters,
            optimizer/scheduler selection, and loop settings (see attribute
            accesses below).

    Fixes vs. previous revision:
        * resume no longer skips one epoch (off-by-one in start_epoch);
        * checkpoint is saved to the same path it is resumed from;
        * plateau/lambda schedulers step once per validation pass, not per batch;
        * running validation metrics are kept as plain floats.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Data open
    write_log(logger, "Load data...")
    gc.disable()  # avoid GC pauses while unpickling the large blob
    with open(os.path.join(args.preprocess_path, 'processed.pkl'), 'rb') as f:
        data_ = pickle.load(f)
        train_src_indices = data_['train_src_indices']
        valid_src_indices = data_['valid_src_indices']
        train_trg_indices = data_['train_trg_indices']
        valid_trg_indices = data_['valid_trg_indices']
        src_word2id = data_['src_word2id']
        trg_word2id = data_['trg_word2id']
        src_vocab_num = len(src_word2id)
        trg_vocab_num = len(trg_word2id)
        del data_
    gc.enable()
    write_log(logger, "Finished loading data!")

    # 2) Dataloader setting
    dataset_dict = {
        'train': CustomDataset(train_src_indices, train_trg_indices,
                               min_len=args.min_len,
                               src_max_len=args.src_max_len,
                               trg_max_len=args.trg_max_len),
        'valid': CustomDataset(valid_src_indices, valid_trg_indices,
                               min_len=args.min_len,
                               src_max_len=args.src_max_len,
                               trg_max_len=args.trg_max_len),
    }
    dataloader_dict = {
        'train': DataLoader(dataset_dict['train'], drop_last=True,
                            batch_size=args.batch_size, shuffle=True,
                            pin_memory=True, num_workers=args.num_workers),
        'valid': DataLoader(dataset_dict['valid'], drop_last=False,
                            batch_size=args.batch_size, shuffle=False,
                            pin_memory=True, num_workers=args.num_workers)
    }
    write_log(
        logger,
        f"Total number of trainingsets iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Train setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, 'Instantiating model...')
    model = Transformer(
        src_vocab_num=src_vocab_num, trg_vocab_num=trg_vocab_num,
        pad_idx=args.pad_id, bos_idx=args.bos_id, eos_idx=args.eos_id,
        d_model=args.d_model, d_embedding=args.d_embedding,
        n_head=args.n_head, dim_feedforward=args.dim_feedforward,
        num_common_layer=args.num_common_layer,
        num_encoder_layer=args.num_encoder_layer,
        num_decoder_layer=args.num_decoder_layer,
        src_max_len=args.src_max_len, trg_max_len=args.trg_max_len,
        dropout=args.dropout, embedding_dropout=args.embedding_dropout,
        trg_emb_prj_weight_sharing=args.trg_emb_prj_weight_sharing,
        emb_src_trg_weight_sharing=args.emb_src_trg_weight_sharing,
        parallel=args.parallel)
    model.train()
    model = model.to(device)
    # Causal mask for the decoder input, which is shifted right by one token.
    tgt_mask = model.generate_square_subsequent_mask(args.trg_max_len - 1,
                                                     device)

    # 2) Optimizer & Learning rate scheduler setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    scaler = GradScaler()

    # 3) Model resume
    start_epoch = 0
    checkpoint_path = os.path.join(args.save_path,
                                   f'checkpoint_{args.parallel}.pth.tar')
    if args.resume:
        write_log(logger, 'Resume model...')
        # Load the same file the checkpoint code below writes (previously the
        # save and resume paths disagreed).
        checkpoint = torch.load(checkpoint_path)
        # 'epoch' is the last *completed* epoch; the loop starts at
        # start_epoch + 1, so adding 1 here (as before) skipped an epoch.
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_acc = 0.0
    best_epoch = 0
    write_log(logger, 'Training start!')

    for epoch in range(start_epoch + 1, args.num_epochs + 1):
        start_time_e = time()
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                freq = 0  # print-frequency counter (previously uninitialized)
            if phase == 'valid':
                write_log(logger, 'Validation start...')
                val_loss = 0.0
                val_acc = 0.0
                model.eval()

            for i, (src, trg) in enumerate(
                    tqdm(dataloader_dict[phase],
                         bar_format='{l_bar}{bar:30}{r_bar}{bar:-2b}')):

                # Optimizer setting
                optimizer.zero_grad(set_to_none=True)

                # Input, output setting
                src = src.to(device, non_blocking=True)
                trg = trg.to(device, non_blocking=True)

                # Targets are the decoder input shifted left by one;
                # pad positions are dropped from both logits and targets.
                trg_sequences_target = trg[:, 1:]
                non_pad = trg_sequences_target != args.pad_id
                trg_sequences_target = trg_sequences_target[
                    non_pad].contiguous().view(-1)

                # Train
                if phase == 'train':
                    # Loss calculate
                    with autocast():
                        predicted = model(src, trg[:, :-1], tgt_mask,
                                          non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(predicted,
                                                    trg_sequences_target,
                                                    args.pad_id)

                    scaler.scale(loss).backward()
                    # Unscale before clipping so the threshold applies to the
                    # true gradient magnitudes.
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    # Per-step schedulers only.
                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                    # Print loss value only training.  `i + 1 == len(...)` so
                    # the final batch is logged (plain `i == len(...)` can
                    # never be true under enumerate).
                    if i == 0 or freq == args.print_freq or i + 1 == len(
                            dataloader_dict['train']):
                        acc = (predicted.max(dim=1)[1] == trg_sequences_target
                               ).sum() / len(trg_sequences_target)
                        iter_log = "[Epoch:%03d][%03d/%03d] train_loss:%03.3f | train_acc:%03.2f%% | learning_rate:%1.6f | spend_time:%02.2fmin" % \
                            (epoch, i, len(dataloader_dict['train']), loss.item(), acc*100, optimizer.param_groups[0]['lr'], (time() - start_time_e) / 60)
                        write_log(logger, iter_log)
                        freq = 0
                    freq += 1

                # Validation
                if phase == 'valid':
                    with torch.no_grad():
                        predicted = model(src, trg[:, :-1], tgt_mask,
                                          non_pad_position=non_pad)
                        loss = F.cross_entropy(predicted, trg_sequences_target)
                    # .item() keeps the running totals as plain Python floats.
                    val_loss += loss.item()
                    val_acc += ((predicted.max(dim=1)[1] == trg_sequences_target
                                 ).sum() / len(trg_sequences_target)).item()

            if phase == 'valid':
                # Validation-driven schedulers step once per validation pass,
                # not once per batch as before.
                if args.scheduler == 'reduce_valid':
                    scheduler.step(val_loss)
                if args.scheduler == 'lambda':
                    scheduler.step()

                val_loss /= len(dataloader_dict[phase])
                val_acc /= len(dataloader_dict[phase])
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger,
                          'Validation Accuracy: %3.2f%%' % (val_acc * 100))

                if val_acc > best_val_acc:
                    write_log(logger, 'Checkpoint saving...')
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, checkpoint_path)
                    best_val_acc = val_acc
                    best_epoch = epoch
                else:
                    else_log = f'Still {best_epoch} epoch accuracy({round(best_val_acc * 100, 2)})% is better...'
                    write_log(logger, else_log)

    # 3) Print results
    print(f'Best Epoch: {best_epoch}')
    print(f'Best Accuracy: {round(best_val_acc, 2)}')
def training(args):
    """Train the menu-demand Transformer regressor.

    Builds train/valid/test dataloaders from the preprocessed dataset, trains
    the model with native PyTorch mixed precision, validates each epoch, and
    checkpoints whenever validation RMSE improves.

    Args:
        args: argparse.Namespace with paths, model hyper-parameters,
            optimizer/scheduler selection and loop settings.

    Fixes vs. previous revision:
        * removed `amp.initialize(model, optimizer, opt_level='O1')` — apex's
          O1 patching added a second, independent loss scaler on top of the
          native GradScaler/autocast used below, double-scaling gradients;
        * AMP scaler state is restored on resume (it was saved but never
          reloaded);
        * save directory is created with `os.makedirs(..., exist_ok=True)`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Dataloader setting
    write_log(logger, "Load data...")
    gc.disable()  # avoid GC pauses while the datasets load
    dataset_dict = {
        'train': CustomDataset(data_path=args.preprocessed_path, phase='train'),
        'valid': CustomDataset(data_path=args.preprocessed_path, phase='valid'),
        'test': CustomDataset(data_path=args.preprocessed_path, phase='test')
    }
    # Vocabulary size for the model's input embedding.
    unique_menu_count = dataset_dict['train'].unique_count()
    dataloader_dict = {
        'train': DataLoader(dataset_dict['train'], drop_last=True,
                            batch_size=args.batch_size, shuffle=True,
                            pin_memory=True, num_workers=args.num_workers,
                            collate_fn=PadCollate()),
        'valid': DataLoader(dataset_dict['valid'], drop_last=False,
                            batch_size=args.batch_size, shuffle=False,
                            pin_memory=True, num_workers=args.num_workers,
                            collate_fn=PadCollate()),
        'test': DataLoader(dataset_dict['test'], drop_last=False,
                           batch_size=args.batch_size, shuffle=False,
                           pin_memory=True, num_workers=args.num_workers,
                           collate_fn=PadCollate())
    }
    gc.enable()
    write_log(
        logger,
        f"Total number of trainingsets iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Model setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, "Instantiating models...")
    model = Transformer(model_type=args.model_type,
                        input_size=unique_menu_count,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        dropout=args.dropout)
    model = model.train()
    model = model.to(device)

    # 2) Optimizer setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    criterion = nn.MSELoss()
    scaler = GradScaler(enabled=True)
    # NOTE: the former apex call `amp.initialize(model, optimizer,
    # opt_level='O1')` was removed — native autocast/GradScaler below already
    # provides mixed precision, and running both double-scales gradients.

    # 2) Model resume
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(os.path.join(args.model_path,
                                             'checkpoint.pth.tar'),
                                map_location='cpu')
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # Restore AMP scaler state too; it is saved below but was never
        # reloaded, which reset loss scaling on every resume.
        if 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        model = model.train()
        model = model.to(device)
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_rmse = 9999999
    best_epoch = 0
    write_log(logger, 'Train start!')

    for epoch in range(start_epoch, args.num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                train_start_time = time.time()
                freq = 0
            elif phase == 'valid':
                model.eval()
                val_loss = 0
                val_rmse = 0

            for i, (src_menu, label_lunch,
                    label_supper) in enumerate(dataloader_dict[phase]):

                # Optimizer setting
                optimizer.zero_grad()

                # Input, output setting
                src_menu = src_menu.to(device, non_blocking=True)
                label_lunch = label_lunch.float().to(device, non_blocking=True)
                label_supper = label_supper.float().to(device,
                                                       non_blocking=True)

                # Model: gradients only in the train phase.
                with torch.set_grad_enabled(phase == 'train'):
                    with autocast(enabled=True):
                        if args.model_type == 'sep':
                            logit = model(src_menu)
                            logit_lunch = logit[:, 0]
                            # NOTE(review): supper also reads column 0 here —
                            # presumably the 'sep' model emits one value per
                            # run; confirm it should not be logit[:, 1].
                            logit_supper = logit[:, 0]
                        elif args.model_type == 'total':
                            logit = model(src_menu)
                            logit_lunch = logit[:, 0]
                            logit_supper = logit[:, 1]

                        # Loss calculate: total loss is the sum of both meals.
                        loss_lunch = criterion(logit_lunch, label_lunch)
                        loss_supper = criterion(logit_supper, label_supper)
                        loss = loss_lunch + loss_supper

                # Back-propagation
                if phase == 'train':
                    scaler.scale(loss).backward()
                    # Unscale before clipping so the threshold applies to the
                    # true gradient magnitudes.
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    # Scheduler setting
                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                # Print loss value
                rmse_loss = torch.sqrt(loss)
                if phase == 'train':
                    if i == 0 or freq == args.print_freq or i == len(
                            dataloader_dict['train']):
                        batch_log = "[Epoch:%d][%d/%d] train_MSE_loss:%2.3f | train_RMSE_loss:%2.3f | learning_rate:%3.6f | spend_time:%3.2fmin" \
                            % (epoch+1, i, len(dataloader_dict['train']), loss.item(), rmse_loss.item(), optimizer.param_groups[0]['lr'], (time.time() - train_start_time) / 60)
                        write_log(logger, batch_log)
                        freq = 0
                    freq += 1
                elif phase == 'valid':
                    val_loss += loss.item()
                    val_rmse += rmse_loss.item()

            if phase == 'valid':
                val_loss /= len(dataloader_dict['valid'])
                val_rmse /= len(dataloader_dict['valid'])
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger, 'Validation RMSE: %3.3f' % val_rmse)

                if val_rmse < best_val_rmse:
                    write_log(logger, 'Checkpoint saving...')
                    # NOTE(review): resume above reads
                    # args.model_path/checkpoint.pth.tar — confirm the two
                    # paths are meant to differ.
                    os.makedirs(args.save_path, exist_ok=True)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, os.path.join(args.save_path,
                                        'checkpoint_cap.pth.tar'))
                    best_val_rmse = val_rmse
                    best_epoch = epoch
                else:
                    else_log = f'Still {best_epoch} epoch RMSE({round(best_val_rmse, 3)}) is better...'
                    write_log(logger, else_log)

    # 3)
    write_log(logger, f'Best Epoch: {best_epoch+1}')
    write_log(logger, f'Best Accuracy: {round(best_val_rmse, 3)}')
for n, p in model.named_parameters(): if p.dim() > 1 and (n != "embedding.lut.weight" and config.pretrain_emb): xavier_uniform_(p) print("MODEL USED", config.model) print("TRAINABLE PARAMETERS", count_parameters(model)) check_iter = 2000 try: if (config.USE_CUDA): model.cuda() model = model.train() best_ppl = 1000 patient = 0 writer = SummaryWriter(log_dir=config.save_path) weights_best = deepcopy(model.state_dict()) data_iter = make_infinite(data_loader_tra) for n_iter in tqdm(range(1000000)): loss, ppl, bce, acc = model.train_one_batch(next(data_iter), n_iter) writer.add_scalars('loss', {'loss_train': loss}, n_iter) writer.add_scalars('ppl', {'ppl_train': ppl}, n_iter) writer.add_scalars('bce', {'bce_train': bce}, n_iter) writer.add_scalars('accuracy', {'acc_train': acc}, n_iter) if (config.noam): writer.add_scalars('lr', {'learning_rata': model.optimizer._rate}, n_iter) if ((n_iter + 1) % check_iter == 0): model = model.eval() model.epoch = n_iter model.__id__logger = 0
_, _, _ = model.train_one_batch(d) generate(model, val_iter, persona) p = Personas() # Build model, optimizer, and set states print("Test model", config.model) model = Transformer(p.vocab, model_file_path=config.save_path, is_eval=False) # get persona map filename = 'data/ConvAI2/test_persona_map' with open(filename, 'rb') as f: persona_map = pickle.load(f) #generate iterations = 11 weights_original = deepcopy(model.state_dict()) tasks = p.get_personas('test') for per in tqdm(tasks): num_of_dialog = p.get_num_of_dialog(persona=per, split='test') for val_dial_index in range(num_of_dialog): train_iter, val_iter = p.get_data_loader(persona=per, batch_size=config.batch_size, split='test', fold=val_dial_index) persona = [] for ppp in persona_map[per]: persona += ppp persona = list(set(persona)) do_learning(model, train_iter, val_iter,
last_epoch = int(weights[-1].split('_')[-1]) weight_path = weights[-1].replace('\\', '/') print('weight info of last epoch', weight_path) model.load_state_dict(torch.load(weight_path)) total_epoch = last_epoch + plus_epoch else: last_epoch = 0 total_epoch = plus_epoch model.train() for epoch in range(plus_epoch): epoch_loss = 0 for iteration, data in enumerate(data_loader): encoder_inputs, decoder_inputs, targets = data optimizer.zero_grad() logits, _ = model(encoder_inputs, decoder_inputs) logits = logits.contiguous().view(-1, trg_vocab_size) targets = targets.contiguous().view(-1) loss = criterion(logits, targets) loss.backward() nn.utils.clip_grad_norm_(model.parameters(), 0.5) optimizer.step() epoch_loss += loss if (iteration + 1) % 500 == 0: print('Epoch: %3d\t' % (last_epoch + epoch + 1), 'Iteration: %3d \t' % (iteration + 1), 'Cost: {:.5f}'.format(epoch_loss / (iteration + 1))) scheduler.step(epoch_loss) model_path = './weight/transformer_%d' % total_epoch torch.save(model.state_dict(), model_path)
meta_batch_size = config.meta_batch_size tasks = p.get_personas('train') #tasks_loader = {t: p.get_data_loader(persona=t,batch_size=config.batch_size, split='train') for t in tasks} tasks_iter = make_infinite_list(tasks) # meta early stop patience = 50 if config.fix_dialnum_train: patience = 100 best_loss = 10000000 stop_count = 0 # Main loop for meta_iteration in range(config.epochs): ## save original weights to make the update weights_original = deepcopy(meta_net.state_dict()) train_loss_before = [] train_loss_meta = [] #loss accumulate from a batch of tasks batch_loss = 0 for _ in range(meta_batch_size): # Get task if config.fix_dialnum_train: train_iter, val_iter = p.get_balanced_loader( persona=tasks_iter.__next__(), batch_size=config.batch_size, split='train') else: train_iter, val_iter = p.get_data_loader( persona=tasks_iter.__next__(), batch_size=config.batch_size,
class Trainer:
    """Training / evaluation / inference driver for a Transformer NMT model.

    In 'train' mode the trainer holds train and validation iterators; in any
    other mode it holds only the test iterator.  Batches expose the source
    sentence as ``batch.kor`` and the target as ``batch.eng``; decoding is
    teacher-forced (the target shifted right is fed to the decoder).
    """

    def __init__(self,
                 params,
                 mode,
                 train_iter=None,
                 valid_iter=None,
                 test_iter=None):
        # params: configuration object; fields read here and below include
        # device, pad_idx, hidden_dim, warm_steps, num_epoch, clip, save_model.
        self.params = params

        # Train mode
        if mode == 'train':
            self.train_iter = train_iter
            self.valid_iter = valid_iter
        # Test mode
        else:
            self.test_iter = test_iter

        self.model = Transformer(self.params)
        self.model.to(self.params.device)

        # Scheduling Optimizer: Adam wrapped with warm-up LR scheduling.
        self.optimizer = ScheduledAdam(optim.Adam(self.model.parameters(),
                                                  betas=(0.9, 0.98),
                                                  eps=1e-9),
                                       hidden_dim=params.hidden_dim,
                                       warm_steps=params.warm_steps)

        # Pad positions are excluded from the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.params.pad_idx)
        self.criterion.to(self.params.device)

    def train(self):
        """Run the full training loop; save weights whenever validation loss improves."""
        print(self.model)
        print(
            f'The model has {self.model.count_params():,} trainable parameters'
        )
        best_valid_loss = float('inf')

        for epoch in range(self.params.num_epoch):
            self.model.train()
            epoch_loss = 0
            start_time = time.time()

            for batch in self.train_iter:
                # For each batch, first zero the gradients
                self.optimizer.zero_grad()
                source = batch.kor
                target = batch.eng

                # target sentence consists of <sos> and following tokens (except the <eos> token)
                output = self.model(source, target[:, :-1])[0]

                # ground truth sentence consists of tokens and <eos> token (except the <sos> token)
                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)
                # output = [(batch size * target length - 1), output dim]
                # target = [(batch size * target length - 1)]
                loss = self.criterion(output, target)
                loss.backward()

                # clip the gradients to prevent the model from exploding gradient
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.params.clip)

                self.optimizer.step()

                # 'item' method is used to extract a scalar from a tensor which only contains a single value.
                epoch_loss += loss.item()

            train_loss = epoch_loss / len(self.train_iter)
            valid_loss = self.evaluate()

            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # Checkpoint only when validation improves.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(self.model.state_dict(), self.params.save_model)

            print(
                f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s'
            )
            print(
                f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}'
            )

    def evaluate(self):
        """Return the mean validation loss (teacher-forced, no gradients)."""
        self.model.eval()
        epoch_loss = 0

        with torch.no_grad():
            for batch in self.valid_iter:
                source = batch.kor
                target = batch.eng

                # Same shifted-target scheme as in train().
                output = self.model(source, target[:, :-1])[0]

                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                loss = self.criterion(output, target)
                epoch_loss += loss.item()

        return epoch_loss / len(self.valid_iter)

    def inference(self):
        """Load the best saved weights and report mean loss on the test set."""
        self.model.load_state_dict(torch.load(self.params.save_model))
        self.model.eval()
        epoch_loss = 0

        with torch.no_grad():
            for batch in self.test_iter:
                source = batch.kor
                target = batch.eng

                output = self.model(source, target[:, :-1])[0]

                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                loss = self.criterion(output, target)
                epoch_loss += loss.item()

        test_loss = epoch_loss / len(self.test_iter)
        print(f'Test Loss: {test_loss:.3f}')