def train(config):
    """Run one epoch of start-of-line (SOL) pretraining.

    Builds a cropped training dataset from ``config``, runs the SOL
    network over ``images_per_epoch`` worth of batches, optimizes the
    alignment loss, writes a visualization of each batch's first image
    to ``snapshots/sol_train``, and returns the mean loss over the epoch.

    NOTE(review): ``solf``, ``dtype``, ``optimizer``, ``alignment_loss``,
    ``alpha_alignment`` and ``alpha_backprop`` are read from the enclosing
    module/script scope -- confirm they are defined before calling.
    """
    # Fix: the body consistently reads ``pretrain_config`` but only
    # ``config`` is passed in (NameError otherwise). Derive it the same
    # way the other training entry points derive their sub-config.
    # TODO(review): confirm the 'pretraining' key against the caller.
    pretrain_config = config['pretraining']

    training_set_list = load_file_list(pretrain_config['training_set'])
    train_dataset = SolDataset(
        training_set_list,
        rescale_range=pretrain_config['sol']['training_rescale_range'],
        transform=CropTransform(pretrain_config['sol']['crop_params']))

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=pretrain_config['sol']['batch_size'],
        shuffle=True,
        num_workers=0,
        collate_fn=sol.sol_dataset.collate)

    batches_per_epoch = int(pretrain_config['sol']['images_per_epoch'] /
                            pretrain_config['sol']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    if not os.path.exists("snapshots/sol_train"):
        os.makedirs("snapshots/sol_train")

    solf.train()
    sum_loss = 0.0
    steps = 0.0

    for step_i, x in enumerate(train_dataloader):
        img = Variable(x['img'].type(dtype), requires_grad=False)

        sol_gt = None
        if x['sol_gt'] is not None:
            # If sol_gt is None there are no GT positions in the image;
            # the alignment loss handles a None target correctly.
            sol_gt = Variable(x['sol_gt'].type(dtype), requires_grad=False)

        predictions = solf(img)
        loss = alignment_loss(predictions, sol_gt, x['label_sizes'],
                              alpha_alignment, alpha_backprop)

        # Dump a visualization of the first image in the batch with the
        # predicted start-of-line boxes drawn on it.
        org_img = img[0].data.cpu().numpy().transpose([2, 1, 0])
        org_img = ((org_img + 1) * 128).astype(np.uint8)
        org_img = org_img.copy()
        org_img = drawing.draw_sol_torch(predictions, org_img)
        cv2.imwrite("snapshots/sol_train/{}.png".format(step_i), org_img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sum_loss += loss.data.cpu().numpy()
        steps += 1

        # Drop graph references so gc can reclaim memory between batches.
        predictions = None
        loss = None
        gc.collect()

    return sum_loss / steps
def training_step(config): train_config = config['training'] allowed_training_time = train_config['sol']['reset_interval'] init_training_time = time.time() training_set_list = load_file_list(train_config['training_set']) train_dataset = SolDataset( training_set_list, rescale_range=train_config['sol']['training_rescale_range'], transform=CropTransform(train_config['sol']['crop_params'])) train_dataloader = DataLoader(train_dataset, batch_size=train_config['sol']['batch_size'], shuffle=True, num_workers=0, collate_fn=sol_dataset.collate) batches_per_epoch = int(train_config['sol']['images_per_epoch'] / train_config['sol']['batch_size']) train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch) test_set_list = load_file_list(train_config['validation_set']) test_dataset = SolDataset( test_set_list, rescale_range=train_config['sol']['validation_rescale_range'], random_subset_size=train_config['sol']['validation_subset_size'], transform=None) test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=sol_dataset.collate) alpha_alignment = train_config['sol']['alpha_alignment'] alpha_backprop = train_config['sol']['alpha_backprop'] sol, lf, hw = init_model(config, only_load='sol') dtype = torch.cuda.FloatTensor lowest_loss = np.inf lowest_loss_i = 0 epoch = -1 while True: #This ends on a break based on the current itme epoch += 1 print "Train Time:", ( time.time() - init_training_time), "Allowed Time:", allowed_training_time sol.eval() sum_loss = 0.0 steps = 0.0 start_time = time.time() for step_i, x in enumerate(test_dataloader): img = Variable(x['img'].type(dtype), requires_grad=False) sol_gt = None if x['sol_gt'] is not None: sol_gt = Variable(x['sol_gt'].type(dtype), requires_grad=False) predictions = sol(img) predictions = transformation_utils.pt_xyrs_2_xyxy(predictions) loss = alignment_loss(predictions, sol_gt, x['label_sizes'], alpha_alignment, alpha_backprop) sum_loss += loss.data[0] steps += 1 if epoch == 0: 
print "First Validation Step Complete" print "Benchmark Validation CER:", sum_loss / steps lowest_loss = sum_loss / steps sol, lf, hw = init_model(config, sol_dir='current', only_load='sol') optimizer = torch.optim.Adam( sol.parameters(), lr=train_config['sol']['learning_rate']) optim_path = os.path.join(train_config['snapshot']['current'], "sol_optim.pt") if os.path.exists(optim_path): print "Loading Optim Settings" optimizer.load_state_dict(safe_load.torch_state(optim_path)) else: print "Failed to load Optim Settings" elif lowest_loss > sum_loss / steps: lowest_loss = sum_loss / steps print "Saving Best" dirname = train_config['snapshot']['best_validation'] if not len(dirname) != 0 and os.path.exists(dirname): os.makedirs(dirname) save_path = os.path.join(dirname, "sol.pt") torch.save(sol.state_dict(), save_path) lowest_loss_i = epoch print "Test Loss", sum_loss / steps, lowest_loss print "Time:", time.time() - start_time print "" print "Epoch", epoch if allowed_training_time < (time.time() - init_training_time): print "Out of time. Saving current state and exiting..." 
dirname = train_config['snapshot']['current'] if not len(dirname) != 0 and os.path.exists(dirname): os.makedirs(dirname) save_path = os.path.join(dirname, "sol.pt") torch.save(sol.state_dict(), save_path) optim_path = os.path.join(dirname, "sol_optim.pt") torch.save(optimizer.state_dict(), optim_path) break sol.train() sum_loss = 0.0 steps = 0.0 start_time = time.time() for step_i, x in enumerate(train_dataloader): img = Variable(x['img'].type(dtype), requires_grad=False) sol_gt = None if x['sol_gt'] is not None: sol_gt = Variable(x['sol_gt'].type(dtype), requires_grad=False) predictions = sol(img) predictions = transformation_utils.pt_xyrs_2_xyxy(predictions) loss = alignment_loss(predictions, sol_gt, x['label_sizes'], alpha_alignment, alpha_backprop) optimizer.zero_grad() loss.backward() optimizer.step() sum_loss += loss.data[0] steps += 1 print "Train Loss", sum_loss / steps print "Real Epoch", train_dataloader.epoch print "Time:", time.time() - start_time
# --- SOL pretraining setup: datasets, loaders, and a fresh network ---
# NOTE: ``sol`` below starts as the module (sol.sol_dataset.collate) and
# is then rebound to the network instance; the collate references are
# resolved before the rebinding, so the order of these statements matters.

training_set_list = load_file_list(pretrain_config['training_set'])
train_dataset = SolDataset(
    training_set_list,
    rescale_range=pretrain_config['sol']['training_rescale_range'],
    transform=CropTransform(pretrain_config['sol']['crop_params']))

train_dataloader = DataLoader(
    train_dataset,
    batch_size=pretrain_config['sol']['batch_size'],
    shuffle=True,
    num_workers=0,
    collate_fn=sol.sol_dataset.collate)

batches_per_epoch = int(pretrain_config['sol']['images_per_epoch'] /
                        pretrain_config['sol']['batch_size'])
train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

test_set_list = load_file_list(pretrain_config['validation_set'])
test_dataset = SolDataset(
    test_set_list,
    rescale_range=pretrain_config['sol']['validation_rescale_range'],
    transform=None)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=sol.sol_dataset.collate)

base0 = sol_network_config['base0']
base1 = sol_network_config['base1']
sol = StartOfLineFinder(base0, base1)
def training_step(config): hw_network_config = config['network']['hw'] train_config = config['training'] allowed_training_time = train_config['hw']['reset_interval'] init_training_time = time.time() char_set_path = hw_network_config['char_set_path'] with open(char_set_path) as f: char_set = json.load(f) idx_to_char = {} for k, v in char_set['idx_to_char'].iteritems(): idx_to_char[int(k)] = v training_set_list = load_file_list(train_config['training_set']) train_dataset = HwDataset(training_set_list, char_set['char_to_idx'], augmentation=True, img_height=hw_network_config['input_height']) train_dataloader = DataLoader(train_dataset, batch_size=train_config['hw']['batch_size'], shuffle=False, num_workers=0, collate_fn=hw_dataset.collate) batches_per_epoch = int(train_config['hw']['images_per_epoch'] / train_config['hw']['batch_size']) train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch) test_set_list = load_file_list(train_config['validation_set']) test_dataset = HwDataset( test_set_list, char_set['char_to_idx'], img_height=hw_network_config['input_height'], random_subset_size=train_config['hw']['validation_subset_size']) test_dataloader = DataLoader(test_dataset, batch_size=train_config['hw']['batch_size'], shuffle=False, num_workers=0, collate_fn=hw_dataset.collate) hw = cnn_lstm.create_model(hw_network_config) hw_path = os.path.join(train_config['snapshot']['best_validation'], "hw.pt") hw_state = safe_load.torch_state(hw_path) hw.load_state_dict(hw_state) hw.cuda() criterion = CTCLoss() dtype = torch.cuda.FloatTensor lowest_loss = np.inf lowest_loss_i = 0 for epoch in xrange(10000000000): sum_loss = 0.0 steps = 0.0 hw.eval() for x in test_dataloader: sys.stdout.flush() line_imgs = Variable(x['line_imgs'].type(dtype), requires_grad=False, volatile=True) labels = Variable(x['labels'], requires_grad=False, volatile=True) label_lengths = Variable(x['label_lengths'], requires_grad=False, volatile=True) preds = hw(line_imgs).cpu() output_batch = 
preds.permute(1, 0, 2) out = output_batch.data.cpu().numpy() for i, gt_line in enumerate(x['gt']): logits = out[i, ...] pred, raw_pred = string_utils.naive_decode(logits) pred_str = string_utils.label2str_single( pred, idx_to_char, False) cer = error_rates.cer(gt_line, pred_str) sum_loss += cer steps += 1 if epoch == 0: print "First Validation Step Complete" print "Benchmark Validation CER:", sum_loss / steps lowest_loss = sum_loss / steps hw = cnn_lstm.create_model(hw_network_config) hw_path = os.path.join(train_config['snapshot']['current'], "hw.pt") hw_state = safe_load.torch_state(hw_path) hw.load_state_dict(hw_state) hw.cuda() optimizer = torch.optim.Adam( hw.parameters(), lr=train_config['hw']['learning_rate']) optim_path = os.path.join(train_config['snapshot']['current'], "hw_optim.pt") if os.path.exists(optim_path): print "Loading Optim Settings" optimizer.load_state_dict(safe_load.torch_state(optim_path)) else: print "Failed to load Optim Settings" if lowest_loss > sum_loss / steps: lowest_loss = sum_loss / steps print "Saving Best" dirname = train_config['snapshot']['best_validation'] if not len(dirname) != 0 and os.path.exists(dirname): os.makedirs(dirname) save_path = os.path.join(dirname, "hw.pt") torch.save(hw.state_dict(), save_path) lowest_loss_i = epoch print "Test Loss", sum_loss / steps, lowest_loss print "" if allowed_training_time < (time.time() - init_training_time): print "Out of time: Exiting..." 
break print "Epoch", epoch sum_loss = 0.0 steps = 0.0 hw.train() for i, x in enumerate(train_dataloader): line_imgs = Variable(x['line_imgs'].type(dtype), requires_grad=False) labels = Variable(x['labels'], requires_grad=False) label_lengths = Variable(x['label_lengths'], requires_grad=False) preds = hw(line_imgs).cpu() output_batch = preds.permute(1, 0, 2) out = output_batch.data.cpu().numpy() # if i == 0: # for i in xrange(out.shape[0]): # pred, pred_raw = string_utils.naive_decode(out[i,...]) # pred_str = string_utils.label2str_single(pred_raw, idx_to_char, True) # print pred_str for i, gt_line in enumerate(x['gt']): logits = out[i, ...] pred, raw_pred = string_utils.naive_decode(logits) pred_str = string_utils.label2str_single( pred, idx_to_char, False) cer = error_rates.cer(gt_line, pred_str) sum_loss += cer steps += 1 batch_size = preds.size(1) preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size)) loss = criterion(preds, labels, preds_size, label_lengths) optimizer.zero_grad() loss.backward() optimizer.step() print "Train Loss", sum_loss / steps print "Real Epoch", train_dataloader.epoch ## Save current snapshots for next iteration print "Saving Current" dirname = train_config['snapshot']['current'] if not len(dirname) != 0 and os.path.exists(dirname): os.makedirs(dirname) save_path = os.path.join(dirname, "hw.pt") torch.save(hw.state_dict(), save_path) optim_path = os.path.join(dirname, "hw_optim.pt") torch.save(optimizer.state_dict(), optim_path)
def training_step(config): char_set_path = config['network']['hw']['char_set_path'] with open(char_set_path) as f: char_set = json.load(f) idx_to_char = {} for k, v in char_set['idx_to_char'].iteritems(): idx_to_char[int(k)] = v train_config = config['training'] allowed_training_time = train_config['lf']['reset_interval'] init_training_time = time.time() training_set_list = load_file_list(train_config['training_set']) train_dataset = LfDataset(training_set_list, augmentation=True) train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=lf_dataset.collate) batches_per_epoch = int(train_config['lf']['images_per_epoch'] / train_config['lf']['batch_size']) train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch) test_set_list = load_file_list(train_config['validation_set']) test_dataset = LfDataset( test_set_list, random_subset_size=train_config['lf']['validation_subset_size']) test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lf_dataset.collate) _, lf, hw = init_model(config, only_load=['lf', 'hw']) hw.eval() dtype = torch.cuda.FloatTensor lowest_loss = np.inf lowest_loss_i = 0 for epoch in xrange(10000000): lf.eval() sum_loss = 0.0 steps = 0.0 start_time = time.time() for step_i, x in enumerate(test_dataloader): if x is None: continue #Only single batch for now x = x[0] if x is None: continue positions = [ Variable(x_i.type(dtype), requires_grad=False)[None, ...] for x_i in x['lf_xyrs'] ] xy_positions = [ Variable(x_i.type(dtype), requires_grad=False)[None, ...] for x_i in x['lf_xyxy'] ] img = Variable(x['img'].type(dtype), requires_grad=False)[None, ...] 
#There might be a way to handle this case later, #but for now we will skip it if len(xy_positions) <= 1: print "Skipping" continue grid_line, _, _, xy_output = lf(img, positions[:1], steps=len(positions), skip_grid=False) line = torch.nn.functional.grid_sample(img.transpose(2, 3), grid_line) line = line.transpose(2, 3) predictions = hw(line) out = predictions.permute(1, 0, 2).data.cpu().numpy() gt_line = x['gt'] pred, raw_pred = string_utils.naive_decode(out[0]) pred_str = string_utils.label2str_single(pred, idx_to_char, False) cer = error_rates.cer(gt_line, pred_str) sum_loss += cer steps += 1 # l = line[0].transpose(0,1).transpose(1,2) # l = (l + 1)*128 # l_np = l.data.cpu().numpy() # # cv2.imwrite("example_line_out.png", l_np) # print "Saved!" # raw_input() # loss = lf_loss.point_loss(xy_output, xy_positions) # # sum_loss += loss.data[0] # steps += 1 if epoch == 0: print "First Validation Step Complete" print "Benchmark Validation Loss:", sum_loss / steps lowest_loss = sum_loss / steps _, lf, _ = init_model(config, lf_dir='current', only_load="lf") optimizer = torch.optim.Adam( lf.parameters(), lr=train_config['lf']['learning_rate']) optim_path = os.path.join(train_config['snapshot']['current'], "lf_optim.pt") if os.path.exists(optim_path): print "Loading Optim Settings" optimizer.load_state_dict(safe_load.torch_state(optim_path)) else: print "Failed to load Optim Settings" if lowest_loss > sum_loss / steps: lowest_loss = sum_loss / steps print "Saving Best" dirname = train_config['snapshot']['best_validation'] if not len(dirname) != 0 and os.path.exists(dirname): os.makedirs(dirname) save_path = os.path.join(dirname, "lf.pt") torch.save(lf.state_dict(), save_path) lowest_loss_i = 0 test_loss = sum_loss / steps print "Test Loss", sum_loss / steps, lowest_loss print "Time:", time.time() - start_time print "" if allowed_training_time < (time.time() - init_training_time): print "Out of time: Exiting..." 
break print "Epoch", epoch sum_loss = 0.0 steps = 0.0 lf.train() start_time = time.time() for x in train_dataloader: if x is None: continue #Only single batch for now x = x[0] if x is None: continue positions = [ Variable(x_i.type(dtype), requires_grad=False)[None, ...] for x_i in x['lf_xyrs'] ] xy_positions = [ Variable(x_i.type(dtype), requires_grad=False)[None, ...] for x_i in x['lf_xyxy'] ] img = Variable(x['img'].type(dtype), requires_grad=False)[None, ...] #There might be a way to handle this case later, #but for now we will skip it if len(xy_positions) <= 1: continue reset_interval = 4 grid_line, _, _, xy_output = lf(img, positions[:1], steps=len(positions), all_positions=positions, reset_interval=reset_interval, randomize=True, skip_grid=True) loss = lf_loss.point_loss(xy_output, xy_positions) optimizer.zero_grad() loss.backward() optimizer.step() sum_loss += loss.data.item() steps += 1 print "Train Loss", sum_loss / steps print "Real Epoch", train_dataloader.epoch print "Time:", time.time() - start_time ## Save current snapshots for next iteration print "Saving Current" dirname = train_config['snapshot']['current'] if not len(dirname) != 0 and os.path.exists(dirname): os.makedirs(dirname) save_path = os.path.join(dirname, "lf.pt") torch.save(lf.state_dict(), save_path) optim_path = os.path.join(dirname, "lf_optim.pt") torch.save(optimizer.state_dict(), optim_path)