def train(opt, AMP, WdB, train_data_path, train_data_list, test_data_path, test_data_list,
          experiment_name, train_batch_size, val_batch_size, workers, lr, valInterval,
          num_iter, wdbprj, continue_model=''):
    """Build the model, optionally restore a checkpoint, and return (model, converter).

    Unlike the full training entry point below, this variant runs no training loop;
    it only sets up the network, optimizer and CTC label converter.
    """
    HVD3P = pO.HVD or pO.DDP

    torch.cuda.set_device(-1)  # negative index is a no-op; the default CUDA device is kept
    val_batch_size = 1         # force a validation batch size of 1, overriding the argument

    if OnceExecWorker and WdB:
        wandb.init(project=wdbprj, name=experiment_name)
        wandb.config.update(opt)

    train_dataset = ds_load.myLoadDS(train_data_list, train_data_path)

    if opt.num_gpu > 1:
        workers = workers * (1 if HVD3P else opt.num_gpu)

    model = OrigamiNet()
    model.apply(init_bn)

    if not pO.DDP:
        model = model.to(device)
    else:
        model.cuda(opt.rank)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=10**(-1 / 90000))

    if OnceExecWorker and WdB:
        wandb.watch(model, log="all")

    if pO.HVD:
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())

    if pO.DDP and opt.rank != 0:
        random.seed()
        np.random.seed()

    if AMP:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    if pO.DP:
        model = torch.nn.DataParallel(model)
    elif pO.DDP:
        model = pDDP(model, device_ids=[opt.rank], output_device=opt.rank, find_unused_parameters=False)

    model_ema = ModelEma(model)

    if continue_model != '':
        checkpoint = torch.load(continue_model, map_location=f'cuda:{opt.rank}' if HVD3P else None)
        model.load_state_dict(checkpoint['model'], strict=True)
        optimizer.load_state_dict(checkpoint['optimizer'])
        model_ema._load_checkpoint(continue_model, f'cuda:{opt.rank}' if HVD3P else None)

    criterion = torch.nn.CTCLoss(reduction='none', zero_infinity=True).to(device)
    converter = CTCLabelConverter(train_dataset.ralph.values())

    return model, converter
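
# Illustrative sanity check (not part of the pipeline above): the ExponentialLR gamma used
# in train(), 10 ** (-1 / 90000), multiplies the learning rate by 0.1 every 90000 scheduler
# steps, i.e. one order of magnitude of decay per 90k steps. The default base_lr is a
# hypothetical example value.
def _check_lr_decay(base_lr=1e-2, steps=90000):
    gamma = 10 ** (-1 / 90000)
    return base_lr * gamma ** steps  # ~1e-3 for base_lr=1e-2 and steps=90000
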
def predict(self, model, converter, filename):
    """Run greedy CTC inference on a single file and return the decoded text."""
    dataset = ds_load.myLoadDS('', '', single=True, lst=filename)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, pin_memory=True, num_workers=16, sampler=None)

    image_tensors, labels = next(iter(dataloader))
    image = image_tensors.to(device)
    batch_size = 1

    # Inference only: switch to eval mode and disable gradient tracking.
    model.eval()
    model.zero_grad()
    with torch.no_grad():
        preds = model(image, '')
        preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
        preds = preds.permute(1, 0, 2).log_softmax(2)  # (T, N, C) log-probabilities

        # Greedy decoding: best class per timestep, then collapse via the converter.
        _, preds_index = preds.max(2)
        preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
        result = converter.decode(preds_index.data, preds_size.data)

    return result
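
# Illustrative sketch (not part of this codebase) of the greedy CTC decoding that predict()
# delegates to converter.decode: CTCLabelConverter presumably collapses repeated indices and
# drops the blank, assumed here to be index 0. The `alphabet` argument is a hypothetical
# stand-in for the converter's character map; torch is already imported at module level.
def _greedy_ctc_decode(log_probs, alphabet, blank=0):
    # log_probs: (T, N, C) tensor, as produced in predict() after permute + log_softmax
    _, idx = log_probs.max(2)          # best class per timestep: (T, N)
    idx = idx.transpose(1, 0)          # (N, T)
    texts = []
    for seq in idx:
        chars, prev = [], blank
        for k in seq.tolist():
            if k != blank and k != prev:        # collapse repeats, skip blanks
                chars.append(alphabet[k - 1])   # index 0 is reserved for the CTC blank
            prev = k
        texts.append(''.join(chars))
    return texts
# e.g. _greedy_ctc_decode(torch.randn(12, 1, 5).log_softmax(2), alphabet='abcd')
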
def train(opt, AMP, WdB, ralph_path, train_data_path, train_data_list, test_data_path, test_data_list,
          experiment_name, train_batch_size, val_batch_size, workers, lr, valInterval,
          num_iter, wdbprj, continue_model='', finetune=''):
    """Full training entry point: builds train/valid datasets from a supplied alphabet (ralph),
    trains OrigamiNet with CTC loss, validates the EMA weights every valInterval iterations,
    and checkpoints the best model by normalized edit distance."""
    HVD3P = pO.HVD or pO.DDP

    os.makedirs(f'./saved_models/{experiment_name}', exist_ok=True)

    # if OnceExecWorker and WdB:
    #     wandb.init(project=wdbprj, name=experiment_name)
    #     wandb.config.update(opt)

    # Load the supplied ralph (character alphabet) instead of rebuilding it from the data.
    with open(ralph_path, 'r') as f:
        ralph_train = json.load(f)

    print('[4] IN TRAIN; BEFORE MAKING DATASET')
    train_dataset = ds_load.myLoadDS(train_data_list, train_data_path, ralph=ralph_train)
    valid_dataset = ds_load.myLoadDS(test_data_list, test_data_path, ralph=ralph_train)

    # Save ralph for later use:
    # with open(f'./saved_models/{experiment_name}/ralph.json', 'w+') as f:
    #     json.dump(train_dataset.ralph, f)

    print('[5] DATASET DONE LOADING')

    if OnceExecWorker:
        print(pO)
        print('Alphabet :', len(train_dataset.alph), train_dataset.alph)
        for d in [train_dataset, valid_dataset]:
            print('Dataset Size :', len(d.fns))
            # print('Max LbW : ', max(list(map(len, d.tlbls))))
            # print('#Chars : ', sum([len(x) for x in d.tlbls]))
            # print('Sample label :', d.tlbls[-1])
            # print('Dataset :', sorted(list(map(len, d.tlbls))))
            print('-' * 80)

    if opt.num_gpu > 1:
        workers = workers * (1 if HVD3P else opt.num_gpu)

    if HVD3P:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=opt.world_size, rank=opt.rank)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset, num_replicas=opt.world_size, rank=opt.rank)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size,
        shuffle=not HVD3P,  # the DistributedSampler shuffles when HVD/DDP is used
        pin_memory=True, num_workers=int(workers),
        sampler=train_sampler if HVD3P else None,
        worker_init_fn=WrkSeeder,
        collate_fn=ds_load.SameTrCollate)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=val_batch_size,
        pin_memory=True, num_workers=int(workers),
        sampler=valid_sampler if HVD3P else None)

    model = OrigamiNet()
    model.apply(init_bn)

    # Load a finetuning checkpoint, if supplied.
    if finetune != '':
        model = load_finetune(model, finetune)

    model.train()

    if OnceExecWorker:
        import pprint
        for k in sorted(model.lreszs.keys()):
            print(k, model.lreszs[k])

    biparams = list(dict(filter(lambda kv: 'bias' in kv[0], model.named_parameters())).values())
    nonbiparams = list(dict(filter(lambda kv: 'bias' not in kv[0], model.named_parameters())).values())

    if not pO.DDP:
        model = model.to(device)
    else:
        model.cuda(opt.rank)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=10**(-1 / 90000))

    # if OnceExecWorker and WdB:
    #     wandb.watch(model, log="all")

    if pO.HVD:
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
        # optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters(),
        #                                      compression=hvd.Compression.fp16)

    if pO.DDP and opt.rank != 0:
        random.seed()
        np.random.seed()

    # if AMP:
    #     model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if pO.DP:
        model = torch.nn.DataParallel(model)
    elif pO.DDP:
        model = pDDP(model, device_ids=[opt.rank], output_device=opt.rank, find_unused_parameters=False)

    model_ema = ModelEma(model)

    if continue_model != '':
        if OnceExecWorker:
            print(f'loading pretrained model from {continue_model}')
        checkpoint = torch.load(continue_model,
                                map_location=f'cuda:{opt.rank}' if HVD3P else None)
        model.load_state_dict(checkpoint['model'], strict=True)
        optimizer.load_state_dict(checkpoint['optimizer'])
        model_ema._load_checkpoint(continue_model, f'cuda:{opt.rank}' if HVD3P else None)

    criterion = torch.nn.CTCLoss(reduction='none', zero_infinity=True).to(device)
    converter = CTCLabelConverter(train_dataset.ralph.values())

    if OnceExecWorker:
        with open(f'./saved_models/{experiment_name}/opt.txt', 'a') as opt_file:
            opt_log = '------------ Options -------------\n'
            args = vars(opt)
            for k, v in args.items():
                opt_log += f'{str(k)}: {str(v)}\n'
            opt_log += '---------------------------------------\n'
            opt_log += gin.operative_config_str()
            opt_file.write(opt_log)
            # if WdB:
            #     wandb.config.gin_str = gin.operative_config_str().splitlines()

        print(optimizer)
        print(opt_log)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    best_CER = 1e+6
    i = 0
    gAcc = 1            # gradient-accumulation factor
    epoch = 1
    btReplay = False and AMP
    max_batch_replays = 3

    if HVD3P:
        train_sampler.set_epoch(epoch)
    titer = iter(train_loader)

    while True:
        start_time = time.time()
        model.zero_grad()
        train_loss = Metric(pO, 'train_loss')

        for j in trange(valInterval, leave=False, desc='Training'):
            # Load a batch, restarting the iterator at the end of each epoch.
            try:
                image_tensors, labels, fnames = next(titer)
            except StopIteration:
                epoch += 1
                if HVD3P:
                    train_sampler.set_epoch(epoch)
                titer = iter(train_loader)
                image_tensors, labels, fnames = next(titer)

            # Log filenames:
            # fnames = [f'{i}___{fname}' for fname in fnames]
            # with open(f'./saved_models/{experiment_name}/filelog.txt', 'a+') as f:
            #     f.write('\n'.join(fnames) + '\n')

            # Move to device and encode the labels.
            image = image_tensors.to(device)
            text, length = converter.encode(labels)
            batch_size = image.size(0)

            replay_batch = True
            maxR = 3
            while replay_batch and maxR > 0:
                maxR -= 1

                # Forward pass.
                preds = model(image, text).float()
                preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
                preds = preds.permute(1, 0, 2).log_softmax(2)

                if i == 0 and OnceExecWorker:
                    print('Model inp : ', image.dtype, image.size())
                    print('CTC inp : ', preds.dtype, preds.size(), preds_size[0])

                # Disable cudnn for the ctc_loss computation to avoid a known ctc_loss issue.
                torch.backends.cudnn.enabled = False
                cost = criterion(preds, text.to(device), preds_size, length.to(device)).mean() / gAcc
                torch.backends.cudnn.enabled = True

                train_loss.update(cost)

                # Cost tracking:
                # with open(f'./saved_models/{experiment_name}/steplog.txt', 'a+') as f:
                #     f.write(f'Step {i} cost: {cost}\n')

                optimizer.zero_grad()
                default_optimizer_step = optimizer.step  # added for batch replay

                # Backward pass; the AMP/batch-replay path is currently disabled.
                if not AMP:
                    cost.backward()
                    replay_batch = False
                else:
                    # with amp.scale_loss(cost, optimizer) as scaled_loss:
                    #     scaled_loss.backward()
                    #     if pO.HVD: optimizer.synchronize()
                    # if optimizer.step is default_optimizer_step or not btReplay:
                    #     replay_batch = False
                    # elif maxR > 0:
                    #     optimizer.step()
                    pass

            if btReplay:
                pass  # amp._amp_state.loss_scalers[0]._loss_scale = mx_sc

            # Optimizer step every gAcc batches; EMA update and LR decay follow.
            if (i + 1) % gAcc == 0:
                if pO.HVD and AMP:
                    with optimizer.skip_synchronize():
                        optimizer.step()
                else:
                    optimizer.step()
                model.zero_grad()
                model_ema.update(model, num_updates=i / 2)

                if (i + 1) % (gAcc * 2) == 0:
                    lr_scheduler.step()

            i += 1

        # Validation (runs once every valInterval iterations, on the EMA weights).
        elapsed_time = time.time() - start_time
        start_time = time.time()

        model.eval()
        with torch.no_grad():
            valid_loss, current_accuracy, current_norm_ED, ted, bleu, preds, labels, infer_time = validation(
                model_ema.ema, criterion, valid_loader, converter, opt, pO)
        model.train()
        v_time = time.time() - start_time

        if OnceExecWorker:
            if current_norm_ED < best_norm_ED:
                best_norm_ED = current_norm_ED
                checkpoint = {
                    'model': model.state_dict(),
                    'state_dict_ema': model_ema.ema.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(checkpoint, f'./saved_models/{experiment_name}/best_norm_ED.pth')

            if ted < best_CER:
                best_CER = ted
            if current_accuracy > best_accuracy:
                best_accuracy = current_accuracy

            out = f'[{i}] Loss: {train_loss.avg:0.5f} time: ({elapsed_time:0.1f},{v_time:0.1f})'
            out += f' vloss: {valid_loss:0.3f}'
            out += f' CER: {ted:0.4f} NER: {current_norm_ED:0.4f} lr: {lr_scheduler.get_lr()[0]:0.5f}'
            out += f' bAcc: {best_accuracy:0.1f}, bNER: {best_norm_ED:0.4f}, bCER: {best_CER:0.4f}, B: {bleu*100:0.2f}'
            print(out)

            with open(f'./saved_models/{experiment_name}/log_train.txt', 'a') as log:
                log.write(out + '\n')

            # if WdB:
            #     wandb.log({'lr': lr_scheduler.get_lr()[0], 'It': i, 'nED': current_norm_ED, 'B': bleu*100,
            #                'tloss': train_loss.avg, 'AnED': best_norm_ED, 'CER': ted,
            #                'bestCER': best_CER, 'vloss': valid_loss})

        if DEBUG:
            print(f'[!!!] Iteration check. Value of i: {i} | Value of num_iter: {num_iter}')

        # Stop once i reaches num_iter; with num_iter <= 0 the loop runs indefinitely.
        # (Uses i >= num_iter rather than i == num_iter so overshooting still terminates.)
        if num_iter > 0 and i >= num_iter:
            print('end the training')
            # sys.exit()
            break
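
# Illustrative sketch (not part of the training code) of the shape contract the loop above
# relies on: the model emits (N, T, C), preds_size is taken from dim 1 *before* the permute,
# and nn.CTCLoss expects (T, N, C) log-probabilities. All sizes and the random tensors below
# are hypothetical; torch is already imported at module level.
def _ctc_shape_demo():
    N, T, C = 2, 50, 80                                   # batch, timesteps, classes (incl. blank)
    logits = torch.randn(N, T, C)                         # stand-in for model(image, text)
    preds_size = torch.IntTensor([logits.size(1)] * N)    # every sample spans the full T timesteps
    log_probs = logits.permute(1, 0, 2).log_softmax(2)    # (T, N, C), as CTCLoss expects

    targets = torch.randint(1, C, (N, 10), dtype=torch.long)  # label indices; 0 is the CTC blank
    target_lengths = torch.IntTensor([10, 7])                 # true label length per sample

    ctc = torch.nn.CTCLoss(reduction='none', zero_infinity=True)
    return ctc(log_probs, targets, preds_size, target_lengths).mean()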