import logging
import time
import traceback

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

# Project-local imports (assumed from the surrounding repo):
# Config, PickleReader, BatchDataLoader, ContextualBandit, model2, evaluate


def train_model(args):
    print(args)
    print("generating config")
    config = Config(
        input_dim=args.input_dim,
        dropout=args.dropout,
        highway=args.highway,
        nn_layers=args.nn_layers,
    )

    # Encode the full hyper-parameter configuration into the model / log file
    # names. The two names are identical except for the prefix and the label
    # used for the learning-rate field.
    def name_tag(prefix, lr_label):
        return ".".join((
            prefix, str(args.rl_baseline_method), args.sampling_method,
            "gamma", str(args.gamma),
            "beta", str(args.beta),
            "batch", str(args.train_batch),
            lr_label, str(args.lr) + "-" + str(args.lr_sch),
            "bsz", str(args.batch_size),
            "data", args.data_dir.split('/')[0], args.eval_data,
            "input_dim", str(config.input_dim),
            "max_num", str(args.max_num_of_ans),
            "reward", str(args.reward_type),
            "dropout", str(args.dropout) + "-" + str(args.clip_grad),
            "highway", str(args.highway),
            "nn-" + str(args.nn_layers), "ans"))

    model_name = name_tag(args.model_file, "learning_rate")
    log_name = name_tag("log_bert/model", "lr")

    print("initialising data loader and RL learner")
    data_loader = PickleReader(args.data_dir)
    data = args.data_dir.split('/')[0]
    if data == "wiki_qa":
        num_data = 873
    elif data == "trec_qa":
        num_data = 1229
    else:
        raise ValueError("Unknown dataset: %s" % data)

    # Init statistics.
    reward_list = []
    loss_list = []
    best_eval_reward = 0.
    model_save_name = model_name

    bandit = ContextualBandit(b=args.batch_size,
                              rl_baseline_method=args.rl_baseline_method,
                              sample_method=args.sampling_method)
    print("Loaded the Bandit")

    bert_cb = model2.BERT_CB(config)
    print("Loaded the model")
    bert_cb.cuda()

    vocab = "vocab"

    if args.load_ext:
        model_name = args.model_file
        print("loading existing model %s" % model_name)
        bert_cb = torch.load(model_name, map_location=lambda storage, loc: storage)
        bert_cb.cuda()
        model_save_name = model_name
        log_name = "/".join(("log_bert", model_name.split("/")[1]))
        print("finished loading; evaluating model %s" % model_name)
        best_eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                   args.eval_data)[0]

    logging.basicConfig(filename='%s.log' % log_name,
                        level=logging.DEBUG,
                        format='%(asctime)s %(levelname)-10s %(message)s')

    # Loss and optimizer (trainable parameters only).
    optimizer_ans = torch.optim.Adam(
        [param for param in bert_cb.parameters() if param.requires_grad],
        lr=args.lr, betas=(args.beta, 0.999), weight_decay=1e-6)

    if args.lr_sch == 1:
        scheduler = ReduceLROnPlateau(optimizer_ans, 'max', verbose=True,
                                      factor=0.9, patience=3, cooldown=3,
                                      min_lr=9e-5, eps=1e-6)
        if best_eval_reward:
            scheduler.step(best_eval_reward)
            print("init_scheduler")
    elif args.lr_sch == 2:
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer_ans, args.lr, args.lr_2,
            step_size_up=3 * int(num_data / args.train_batch),
            step_size_down=3 * int(num_data / args.train_batch),
            mode='exp_range', gamma=0.98, cycle_momentum=False)

    print("starting training")
    start_time = time.time()
    n_step = 100
    gamma = args.gamma

    # Validate more often on small datasets.
    if num_data < 2000:
        n_val = int(num_data / (5 * args.train_batch))
    else:
        n_val = int(num_data / (7 * args.train_batch))

    with torch.autograd.set_detect_anomaly(True):
        for epoch in tqdm(range(args.epochs_ext), desc="epoch:"):
            train_iter = data_loader.chunked_data_reader(
                "train", data_quota=args.train_example_quota)
            step_in_epoch = 0
            for dataset in train_iter:
                for step, contexts in tqdm(
                        enumerate(BatchDataLoader(dataset,
                                                  batch_size=args.train_batch,
                                                  shuffle=True))):
                    try:
                        bert_cb.train()
                        step_in_epoch += 1
                        loss = 0.
                        reward = 0.
                        prt = False
                        for context in contexts:
                            pre_processed, a_len, sorted_id = model2.bert_preprocess(
                                context.answers)
                            q_a = pre_processed.type(torch.float)
                            outputs = bert_cb(q_a, a_len)
                            context.labels = np.array(context.labels)[sorted_id]

                            # Occasionally print diagnostics.
                            prt = bool(args.prt_inf and
                                       np.random.randint(0, 100) == 0)

                            loss_t, reward_t = bandit.train(
                                outputs, context,
                                max_num_of_ans=args.max_num_of_ans,
                                reward_type=args.reward_type,
                                prt=prt)

                            # Supervised (maximum-likelihood) loss against the
                            # binarized gold labels.
                            true_labels = np.zeros(len(context.labels))
                            gold_labels = np.array(context.labels)
                            true_labels[gold_labels > 0] = 1.0
                            ml_loss = F.binary_cross_entropy(
                                outputs.view(-1),
                                torch.tensor(true_labels).type(torch.float).cuda())

                            # Mixed objective: gamma weights the RL (bandit)
                            # loss against the supervised ML loss.
                            loss_e = (gamma * loss_t) + ((1 - gamma) * ml_loss)
                            loss_e.backward()
                            loss += loss_e.item()
                            reward += reward_t

                        loss = loss / args.train_batch
                        reward = reward / args.train_batch
                        if prt:
                            print('Probabilities: ',
                                  outputs.squeeze().data.cpu().numpy())
                            print('-' * 80)

                        reward_list.append(reward)
                        loss_list.append(loss)

                        if args.clip_grad:
                            torch.nn.utils.clip_grad_norm_(
                                bert_cb.parameters(), args.clip_grad)
                        optimizer_ans.step()
                        optimizer_ans.zero_grad()
                        if args.lr_sch == 2:
                            scheduler.step()
                        logging.info('Epoch %d Step %d Reward %.4f Loss %.4f'
                                     % (epoch, step_in_epoch, reward, loss))
                    except Exception as e:
                        print(e)
                        traceback.print_exc()

                    if step_in_epoch % n_step == 0 and step_in_epoch != 0:
                        logging.info('Epoch %d Step %d reward: %s loss: %s'
                                     % (epoch, step_in_epoch,
                                        np.mean(reward_list), np.mean(loss_list)))
                        reward_list = []
                        loss_list = []

                    if step_in_epoch % n_val == 0 and step_in_epoch != 0:
                        print("doing evaluation")
                        bert_cb.eval()
                        eval_reward = evaluate.ext_model_eval(
                            bert_cb, vocab, args, args.eval_data)
                        if eval_reward[0] > best_eval_reward:
                            best_eval_reward = eval_reward[0]
                            print("saving model %s with eval_reward:"
                                  % model_save_name, eval_reward)
                            logging.debug("saving model %s with eval_reward: %s"
                                          % (model_save_name, eval_reward))
                            torch.save(bert_cb, model_name)
                        print('epoch %d reward in validation: %s'
                              % (epoch, eval_reward))
                        logging.debug('epoch %d reward in validation: %s'
                                      % (epoch, eval_reward))
                        logging.debug('time elapsed: %s'
                                      % (time.time() - start_time))

            if args.lr_sch == 1:
                bert_cb.eval()
                eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                      args.eval_data)
                scheduler.step(eval_reward[0])
    return bert_cb
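# ----------------------------------------------------------------------------
# The per-context objective above mixes a REINFORCE-style bandit loss with a
# supervised binary-cross-entropy term, weighted by gamma. ContextualBandit
# is project-local, so the sketch below only illustrates that weighting; the
# function name and signature are hypothetical, not the repo's actual API.

import torch.nn.functional as F


def mixed_loss_sketch(probs, gold_labels, sample_log_prob, reward, gamma):
    """probs: (n,) answer probabilities; gold_labels: (n,) float {0,1} labels;
    sample_log_prob: log-probability of the sampled answer set under the
    policy; reward: its scalar reward."""
    rl_loss = -reward * sample_log_prob                   # REINFORCE estimator
    ml_loss = F.binary_cross_entropy(probs, gold_labels)  # supervised term
    return gamma * rl_loss + (1 - gamma) * ml_loss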
import csv
import os
import sys
from collections import OrderedDict
from datetime import datetime

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Project-local imports (assumed from the surrounding repo):
# dataset, model_factory, nadam, get_outdir, CheckpointSaver,
# train_epoch, validate, adjust_learning_rate, parser


def main():
    args = parser.parse_args()

    output_base = args.output if args.output else './output'
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        args.gp,
        'f' + str(args.fold)])
    output_dir = get_outdir(output_base, 'train', exp_name)

    train_input_root = os.path.join(args.data)
    batch_size = args.batch_size
    num_epochs = args.epochs
    wav_size = (16000,)
    num_classes = len(dataset.get_labels())

    torch.manual_seed(args.seed)

    model = model_factory.create_model(
        args.model,
        in_chs=1,
        pretrained=args.pretrained,
        num_classes=num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)

    dataset_train = dataset.CommandsDataset(
        root=train_input_root,
        mode='train',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )
    loader_train = data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=args.workers
    )

    dataset_eval = dataset.CommandsDataset(
        root=train_input_root,
        mode='validate',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )
    loader_eval = data.DataLoader(
        dataset_eval,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=args.workers
    )

    train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = validate_loss_fn.cuda()

    opt_params = list(model.parameters())
    if args.opt.lower() == 'sgd':
        optimizer = optim.SGD(
            opt_params, lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay,
            nesterov=True)
    elif args.opt.lower() == 'adam':
        optimizer = optim.Adam(
            opt_params, lr=args.lr,
            weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'nadam':
        optimizer = nadam.Nadam(
            opt_params, lr=args.lr,
            weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'adadelta':
        optimizer = optim.Adadelta(
            opt_params, lr=args.lr,
            weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'rmsprop':
        optimizer = optim.RMSprop(
            opt_params, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        raise ValueError("Invalid optimizer: %s" % args.opt)
    del opt_params

    if not args.decay_epochs:
        print('No decay epoch set, using plateau scheduler.')
        lr_scheduler = ReduceLROnPlateau(optimizer, patience=10)
    else:
        lr_scheduler = None

    # Optionally resume from a checkpoint.
    start_epoch = 0 if args.start_epoch is None else args.start_epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                if 'args' in checkpoint:
                    print(checkpoint['args'])
                # Strip 'module.' prefixes left over from DataParallel training.
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    name = k[7:] if k.startswith('module') else k
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
                if 'optimizer' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                if 'loss' in checkpoint:
                    train_loss_fn.load_state_dict(checkpoint['loss'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
                start_epoch = checkpoint['epoch'] if args.start_epoch is None \
                    else args.start_epoch
            else:
                model.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            sys.exit(1)

    saver = CheckpointSaver(checkpoint_dir=output_dir)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # Optional fine-tune of only the final classifier weights for the
    # specified number of epochs (possibly fractional).
    if not args.resume and args.ft_epochs > 0.:
        if isinstance(model, torch.nn.DataParallel):
            classifier_params = model.module.get_classifier().parameters()
        else:
            classifier_params = model.get_classifier().parameters()
        if args.opt.lower() == 'adam':
            finetune_optimizer = optim.Adam(
                classifier_params, lr=args.ft_lr,
                weight_decay=args.weight_decay)
        else:
            finetune_optimizer = optim.SGD(
                classifier_params, lr=args.ft_lr, momentum=args.momentum,
                weight_decay=args.weight_decay, nesterov=True)

        finetune_epochs_int = int(np.ceil(args.ft_epochs))
        # Number of batches to run in the (possibly partial) final epoch.
        finetune_final_batches = int(np.ceil(
            (1 - (finetune_epochs_int - args.ft_epochs)) * len(loader_train)))
        print(finetune_epochs_int, finetune_final_batches)
        for fepoch in range(0, finetune_epochs_int):
            if fepoch == finetune_epochs_int - 1 and finetune_final_batches:
                batch_limit = finetune_final_batches
            else:
                batch_limit = 0
            train_epoch(
                fepoch, model, loader_train, finetune_optimizer, train_loss_fn,
                args, output_dir=output_dir, batch_limit=batch_limit)

    best_loss = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.decay_epochs:
                adjust_learning_rate(
                    optimizer, epoch,
                    initial_lr=args.lr,
                    decay_rate=args.decay_rate,
                    decay_epochs=args.decay_epochs)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                saver=saver, output_dir=output_dir)

            # Save a recovery checkpoint in case validation blows up.
            saver.save_recovery({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': train_loss_fn.state_dict(),
                'args': args,
                'gp': args.gp,
            },
                epoch=epoch + 1,
                batch_idx=0)

            step = epoch * len(loader_train)
            eval_metrics = validate(
                step, model, loader_eval, validate_loss_fn, args,
                output_dir=output_dir)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_metrics['eval_loss'])

            rowd = OrderedDict(epoch=epoch)
            rowd.update(train_metrics)
            rowd.update(eval_metrics)
            with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                if best_loss is None:  # first iteration (epoch == 1 can't be used)
                    dw.writeheader()
                dw.writerow(rowd)

            # Save a proper checkpoint keyed on the eval metric.
            best_loss = saver.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                'gp': args.gp,
            },
                epoch=epoch + 1,
                metric=eval_metrics['eval_loss'])

    except KeyboardInterrupt:
        pass
    if best_loss is not None:
        print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
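# ----------------------------------------------------------------------------
# adjust_learning_rate is defined elsewhere in the repo. A minimal sketch of
# the step-decay behaviour its call site implies (multiply the LR by
# decay_rate every decay_epochs epochs); treat this as an assumption, not the
# project's actual implementation.

def adjust_learning_rate_sketch(optimizer, epoch, initial_lr,
                                decay_rate=0.1, decay_epochs=30):
    # Step decay: lr = initial_lr * decay_rate ^ (epoch // decay_epochs)
    lr = initial_lr * (decay_rate ** (epoch // decay_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr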
# Fragment from inside the epoch / mini-batch training loop; net, criterion,
# optimizer, scheduler, running_loss, n_image_total, is_lr_just_decayed and
# shall_stop are defined in the enclosing (omitted) code.

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if n_image_total % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            # Update the LR if needed; this custom scheduler reports whether
            # the best loss changed and whether the LR was decayed.
            is_best_changed, is_lr_decayed = scheduler.step(running_loss / 2000,
                                                            n_image_total)
            # Early-stop if the LR was just decayed and the loss still did not
            # improve afterwards.
            if is_lr_just_decayed and (not is_best_changed):
                shall_stop = True
                break
            is_lr_just_decayed = is_lr_decayed
            running_loss = 0.0
        n_image_total += 1

    if shall_stop:
        break

print('Finished Training')

dataiter = iter(testloader)
images, labels = next(dataiter)
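# ----------------------------------------------------------------------------
# The scheduler above is not torch.optim.lr_scheduler.ReduceLROnPlateau: its
# step() takes (metric, step_count) and returns the pair (is_best_changed,
# is_lr_decayed). A minimal, hypothetical sketch of that contract:

class PlateauSchedulerSketch:
    def __init__(self, optimizer, factor=0.1, patience=5):
        self.optimizer = optimizer
        self.factor = factor
        self.patience = patience
        self.best = float('inf')
        self.bad_steps = 0

    def step(self, metric, step_count):
        # step_count is accepted for logging parity but unused here.
        is_best_changed = metric < self.best
        is_lr_decayed = False
        if is_best_changed:
            self.best = metric
            self.bad_steps = 0
        else:
            self.bad_steps += 1
            if self.bad_steps > self.patience:
                for group in self.optimizer.param_groups:
                    group['lr'] *= self.factor
                self.bad_steps = 0
                is_lr_decayed = True
        return is_best_changed, is_lr_decayed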
import datetime
import shutil
import sys

import torch
from torch.utils.data import DataLoader

# Project-local imports (assumed from the surrounding repo):
# imagetransforms, OcrDataset, GroupedSampler, SortByWidthCollater,
# CnnOcrModel, CTCLoss, ReduceLROnPlateau (custom), train, test_on_val,
# pretty_print_timespan, get_args, logger


def main():
    logger.info("Starting training\n\n")
    sys.stdout.flush()

    args = get_args()

    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=args.line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])

    # Note: cudnn benchmark mode is left disabled.
    torch.backends.cudnn.benchmark = False

    train_dataset = OcrDataset(args.datadir, "train", line_img_transforms)
    validation_dataset = OcrDataset(args.datadir, "validation",
                                    line_img_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    validation_dataloader = DataLoader(validation_dataset,
                                       args.batch_size,
                                       num_workers=0,
                                       sampler=GroupedSampler(
                                           validation_dataset, rand=False),
                                       collate_fn=SortByWidthCollater,
                                       pin_memory=False,
                                       drop_last=False)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
    else:
        model = CnnOcrModel(num_in_channels=1,
                            input_line_height=args.line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set training mode on all sub-modules.
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_alpha)
    scheduler = ReduceLROnPlateau(optimizer, mode='min',
                                  patience=args.patience, min_lr=args.min_lr)

    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # Main OCR model training loop.
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size
            loss_array.append(loss)

            logger.info("Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s" %
                        (iteration, iteration % epoch_size, epoch_size, epoch,
                         loss, pretty_print_timespan(elapsed_time)))

            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(validation_dataloader,
                                                         model, ctc_loss)

                # Reduce learning rate on plateau. This is a project-local
                # ReduceLROnPlateau whose step() returns True when it lowers
                # the LR and which exposes .finished, .factor and .min_lr
                # (see the sketch after this function).
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True
                    # For bookkeeping only.
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info("Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                            (val_loss, val_cer, val_wer))

                torch.save({
                    'iteration': iteration,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_hyper_params': model.get_hyper_params(),
                    'cur_lr': lr_alpha,
                    'val_loss': val_loss,
                    'val_cer': val_cer,
                    'val_wer': val_wer,
                    'line_height': args.line_height
                }, snapshot_path)

                # Track LR changes against WER, CER, and loss for plotting.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration / snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info("Best model so far, copying snapshot to best model file")
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    logger.info("Switching to best model parameters before continuing with lower LR")
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    logger.info("Done.")
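# ----------------------------------------------------------------------------
# The ReduceLROnPlateau used above is project-local, not
# torch.optim.lr_scheduler.ReduceLROnPlateau: its step() returns True when it
# lowers the LR, and the object exposes .factor, .min_lr and a .finished flag
# once the LR bottoms out. A minimal, hypothetical sketch of that contract:

class ReduceLROnPlateauSketch:
    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 min_lr=1e-6):
        assert mode == 'min'
        self.optimizer = optimizer
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float('inf')
        self.num_bad_steps = 0
        self.finished = False

    def step(self, metric):
        """Return True if the learning rate was lowered on this call."""
        if metric < self.best:
            self.best = metric
            self.num_bad_steps = 0
            return False
        self.num_bad_steps += 1
        if self.num_bad_steps <= self.patience:
            return False
        self.num_bad_steps = 0
        for group in self.optimizer.param_groups:
            new_lr = max(group['lr'] * self.factor, self.min_lr)
            if new_lr <= self.min_lr:
                self.finished = True  # LR bottomed out; caller may early-exit
            group['lr'] = new_lr
        return True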
import datetime
import os
import shutil
import sys

import torch
from imgaug import augmenters as iaa
from torch.utils.data import DataLoader

# Project-local imports (assumed from the surrounding repo):
# imagetransforms, daves_augment, OcrDataset, OcrDatasetUnion, GroupedSampler,
# SortByWidthCollater, CnnOcrModel, CTCLoss, ReduceLROnPlateau (custom),
# train, test_on_val, test_on_val_writeout, pretty_print_timespan, get_args,
# logger


def main():
    logger.info("Starting training\n\n")
    sys.stdout.flush()

    args = get_args()

    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = []

    # Always convert to color so the augmentations work (for now), then
    # convert back to grayscale later if needed.
    line_img_transforms.append(imagetransforms.ConvertColor())

    # Data augmentations (during training only).
    if args.daves_augment:
        line_img_transforms.append(daves_augment.ImageAug())

    if args.synth_input:
        # Randomly rotate image from -2 degrees to +2 degrees.
        line_img_transforms.append(
            imagetransforms.Randomize(0.3,
                                      imagetransforms.RotateRandom(-2, 2)))

        # Choose one method to blur/pixel-ify the image (or neither, via the
        # identity transform).
        line_img_transforms.append(
            imagetransforms.PickOne([
                imagetransforms.TessBlockConv(kernel_val=1, bias_val=1),
                imagetransforms.TessBlockConv(rand=True),
                imagetransforms.Identity(),
            ]))

        aug_cn = iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5)
        line_img_transforms.append(
            imagetransforms.Randomize(0.5, lambda x: aug_cn.augment_image(x)))

        # With some probability, choose one of:
        #   Grayscale: convert to grayscale and blend back into the color
        #              image with random alpha
        #   Emboss:    emboss the image with random strength
        #   Invert:    invert the image colors (per-channel or whole image)
        aug_gray = iaa.Grayscale(alpha=(0.0, 1.0))
        aug_emboss = iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))
        aug_invert = iaa.Invert(1, per_channel=True)
        aug_invert2 = iaa.Invert(0.1, per_channel=False)
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.PickOne([
                    lambda x: aug_gray.augment_image(x),
                    lambda x: aug_emboss.augment_image(x),
                    lambda x: aug_invert.augment_image(x),
                    lambda x: aug_invert2.augment_image(x)
                ])))

    # Randomly crop close to the top/bottom and left/right of lines. For now
    # we are just guessing (up to 5% of the ends and up to 10% of the
    # tops/bottoms chopped off).
    if args.tight_crop:
        # To keep padding reasonably consistent, first resize the image to the
        # target line height, then add padding to this version of the image.
        # Below it gets resized again to the target line height.
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.9,
                imagetransforms.Compose([
                    imagetransforms.Scale(new_h=args.line_height),
                    imagetransforms.PadRandom(pxl_max_horizontal=30,
                                              pxl_max_vertical=10)
                ])))
    else:
        line_img_transforms.append(
            imagetransforms.Randomize(0.2,
                                      imagetransforms.CropHorizontal(.05)))
        line_img_transforms.append(
            imagetransforms.Randomize(0.2, imagetransforms.CropVertical(.1)))

    # Make sure to resize after the degradation steps above.
    line_img_transforms.append(imagetransforms.Scale(new_h=args.line_height))

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    # Only invert for grayscale input.
    if args.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    if args.stripe:
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.AddRandomStripe(val=0, strip_width_from=1,
                                                strip_width_to=4)))

    line_img_transforms.append(imagetransforms.ToTensor())
    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    # Note: cudnn benchmark mode is left disabled.
    torch.backends.cudnn.benchmark = False

    if len(args.datadir) == 1:
        train_dataset = OcrDataset(args.datadir[0], "train",
                                   line_img_transforms)
        validation_dataset = OcrDataset(args.datadir[0], "validation",
                                        line_img_transforms)
    else:
        train_dataset = OcrDatasetUnion(args.datadir, "train",
                                        line_img_transforms)
        validation_dataset = OcrDatasetUnion(args.datadir, "validation",
                                             line_img_transforms)

    if args.test_datadir is not None:
        if args.test_outdir is None:
            print("Error: must specify both --test-datadir and --test-outdir together")
            sys.exit(1)
        if not os.path.exists(args.test_outdir):
            os.makedirs(args.test_outdir)
        line_img_transforms_test = imagetransforms.Compose([
            imagetransforms.Scale(new_h=args.line_height),
            imagetransforms.ToTensor()
        ])
        test_dataset = OcrDataset(args.test_datadir, "test",
                                  line_img_transforms_test)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
        print("Overriding automatically learned alphabet with pre-saved model alphabet")
        train_dataset.alphabet = model.alphabet
        validation_dataset.alphabet = model.alphabet
        if len(args.datadir) > 1:
            # OcrDatasetUnion wraps several datasets; propagate the alphabet
            # to each of them.
            for ds in train_dataset.datasets:
                ds.alphabet = model.alphabet
            for ds in validation_dataset.datasets:
                ds.alphabet = model.alphabet
    else:
        model = CnnOcrModel(num_in_channels=args.num_in_channels,
                            input_line_height=args.line_height,
                            rds_line_height=args.rds_line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set up the dataloaders only after we have had a chance to (maybe!)
    # override the dataset alphabet from a pre-trained model.
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    if args.max_val_size > 0:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset,
                                               max_items=args.max_val_size,
                                               fixed_rand=True),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)
    else:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset, rand=False),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)

    if args.test_datadir is not None:
        test_dataloader = DataLoader(test_dataset,
                                     args.batch_size,
                                     num_workers=0,
                                     sampler=GroupedSampler(test_dataset,
                                                            rand=False),
                                     collate_fn=SortByWidthCollater,
                                     pin_memory=False,
                                     drop_last=False)

    # Set training mode on all sub-modules.
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_alpha,
                                 weight_decay=args.weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min',
                                  patience=args.patience, min_lr=args.min_lr)

    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)
    do_test_write = False

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # Main OCR model training loop.
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size
            loss_array.append(loss)

            logger.info("Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s" %
                        (iteration, iteration % epoch_size, epoch_size, epoch,
                         loss, pretty_print_timespan(elapsed_time)))

            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(validation_dataloader,
                                                         model, ctc_loss)
                # Only turn on writing test-set hypotheses once the CER starts
                # to get non-random.
                if val_cer < 0.5:
                    do_test_write = True

                if args.test_datadir is not None and do_test_write:
                    out_hyp_outdomain_file = os.path.join(
                        args.test_outdir, "hyp-%07d.outdomain.utf8" % iteration)
                    out_hyp_indomain_file = os.path.join(
                        args.test_outdir, "hyp-%07d.indomain.utf8" % iteration)
                    out_meta_file = os.path.join(args.test_outdir,
                                                 "hyp-%07d.meta" % iteration)
                    test_on_val_writeout(test_dataloader, model,
                                         out_hyp_outdomain_file)
                    test_on_val_writeout(validation_dataloader, model,
                                         out_hyp_indomain_file)
                    with open(out_meta_file, 'w') as fh_out:
                        fh_out.write("%d,%f,%f,%f\n" %
                                     (iteration, val_cer, val_wer, val_loss))

                # Reduce learning rate on plateau (project-local scheduler:
                # step() returns True when it lowers the LR and exposes
                # .finished, .factor and .min_lr).
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True
                    # For bookkeeping only.
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info("Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                            (val_loss, val_cer, val_wer))

                torch.save({
                    'iteration': iteration,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_hyper_params': model.get_hyper_params(),
                    'rtl': args.rtl,
                    'cur_lr': lr_alpha,
                    'val_loss': val_loss,
                    'val_cer': val_cer,
                    'val_wer': val_wer,
                    'line_height': args.line_height
                }, snapshot_path)

                # Track LR changes against WER, CER, and loss for plotting.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration / snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info("Best model so far, copying snapshot to best model file")
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    logger.info("Switching to best model parameters before continuing with lower LR")
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    logger.info("Done.")
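# ----------------------------------------------------------------------------
# The augmentation pipeline above leans on two project-local combinators:
# Randomize(p, t), which applies transform t with probability p, and
# PickOne([...]), which applies one transform chosen uniformly at random.
# A minimal, hypothetical sketch of both (not the repo's actual
# imagetransforms implementation):

import random


class RandomizeSketch:
    def __init__(self, p, transform):
        self.p = p
        self.transform = transform

    def __call__(self, img):
        # Apply the wrapped transform with probability p, else pass through.
        return self.transform(img) if random.random() < self.p else img


class PickOneSketch:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Apply exactly one transform, chosen uniformly at random.
        return random.choice(self.transforms)(img)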