def trainIters(args):
    """Train the Watch (encoder) / Spell (decoder) pair for args['ITER'] epochs.

    Builds the models, optimizers, step-decay LR schedulers, and data loaders
    from the ``args`` dict, runs one training pass and one validation pass per
    epoch, prints per-batch and per-epoch loss / character-error-rate, and
    checkpoints whole models every args['SAVE_EVERY'] epochs.

    Args:
        args: configuration dict; keys used here are 'LANGUAGE', 'LAYER_SIZE',
            'HIDDEN_SIZE', 'LEARNING_RATE', 'LEARNING_RATE_DECAY_EPOCH',
            'LEARNING_RATE_DECAY_RATIO', 'PATH', 'BS', 'VMAX', 'TMAX',
            'WORKER', 'VALIDATION_RATIO', 'ITER', 'SAVE_EVERY'.

    Side effects:
        Prints progress to stdout and writes 'watch{epoch}.pt' /
        'spell{epoch}.pt' checkpoint files (full pickled models).
    """
    charSet = CharSet(args['LANGUAGE'])
    watch = Watch(args['LAYER_SIZE'], args['HIDDEN_SIZE'], args['HIDDEN_SIZE'])
    spell = Spell(args['LAYER_SIZE'], args['HIDDEN_SIZE'],
                  charSet.get_total_num())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    watch = watch.to(device)
    spell = spell.to(device)

    watch_optimizer = optim.Adam(watch.parameters(), lr=args['LEARNING_RATE'])
    spell_optimizer = optim.Adam(spell.parameters(), lr=args['LEARNING_RATE'])
    watch_scheduler = optim.lr_scheduler.StepLR(
        watch_optimizer,
        step_size=args['LEARNING_RATE_DECAY_EPOCH'],
        gamma=args['LEARNING_RATE_DECAY_RATIO'])
    spell_scheduler = optim.lr_scheduler.StepLR(
        spell_optimizer,
        step_size=args['LEARNING_RATE_DECAY_EPOCH'],
        gamma=args['LEARNING_RATE_DECAY_RATIO'])

    # Padding positions must not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=charSet.get_index_of('<pad>'))

    train_loader, eval_loader = get_dataloaders(
        args['PATH'], args['BS'], args['VMAX'], args['TMAX'],
        args['WORKER'], charSet, args['VALIDATION_RATIO'])
    total_batch = len(train_loader)
    total_eval_batch = len(eval_loader)

    for epoch in range(args['ITER']):
        avg_loss = 0.0
        avg_eval_loss = 0.0
        avg_cer = 0.0
        avg_eval_cer = 0.0

        watch = watch.train()
        spell = spell.train()
        for i, (data, labels) in enumerate(train_loader):
            loss, cer = train(data, labels, watch, spell, watch_optimizer,
                              spell_optimizer, criterion, True, charSet)
            avg_loss += loss
            avg_cer += cer
            print('Batch : ', i + 1, '/', total_batch,
                  ', ERROR in this minibatch: ', loss)
            print('Character error rate : ', cer)

        watch = watch.eval()
        spell = spell.eval()
        for k, (data, labels) in enumerate(eval_loader):
            loss, cer = train(data, labels, watch, spell, watch_optimizer,
                              spell_optimizer, criterion, False, charSet)
            avg_eval_loss += loss
            avg_eval_cer += cer

        # BUG FIX: since PyTorch 1.1 the LR schedulers must be stepped AFTER
        # the epoch's optimizer updates; the original stepped them at the top
        # of the epoch, which skipped the initial learning rate and shifted
        # the whole decay schedule by one epoch.
        watch_scheduler.step()
        spell_scheduler.step()

        print('epoch:', epoch, ' train_loss:', float(avg_loss / total_batch))
        print('epoch:', epoch, ' Average CER:', float(avg_cer / total_batch))
        print('epoch:', epoch, ' Validation_loss:',
              float(avg_eval_loss / total_eval_batch))
        print('epoch:', epoch, ' Average CER:',
              float(avg_eval_cer / total_eval_batch))

        if epoch % args['SAVE_EVERY'] == 0 and epoch != 0:
            # Saves the full model objects (not just state_dicts).
            torch.save(watch, 'watch{}.pt'.format(epoch))
            torch.save(spell, 'spell{}.pt'.format(epoch))
def trainIters(n_iters, videomax, txtmax, data_path, batch_size, worker,
               ratio_of_validation=0.0001, learning_rate_decay=2000,
               save_every=30, learning_rate=0.01):
    """Train hard-coded Watch/Spell models for ``n_iters`` epochs.

    NOTE(review): this redefines ``trainIters`` — at import time it shadows
    the dict-driven definition earlier in this file; confirm which one the
    project intends to keep.

    Args:
        n_iters: number of training epochs.
        videomax: maximum video length passed to the data loaders.
        txtmax: maximum transcript length passed to the data loaders.
        data_path: root path of the dataset.
        batch_size: minibatch size.
        worker: number of DataLoader worker processes.
        ratio_of_validation: fraction of data held out for validation.
        learning_rate_decay: scheduler step size in epochs (gamma fixed at 0.1).
        save_every: checkpoint interval in epochs.
        learning_rate: initial Adam learning rate for both models.

    Side effects:
        Prints progress to stdout and writes 'watch{epoch}.pt' /
        'spell{epoch}.pt' checkpoint files (full pickled models).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Model sizes are fixed here; output vocabulary comes from int_list.
    watch = Watch(3, 512, 512)
    spell = Spell(num_layers=3, output_size=len(int_list), hidden_size=512)
    watch = watch.to(device)
    spell = spell.to(device)

    watch_optimizer = optim.Adam(watch.parameters(), lr=learning_rate)
    spell_optimizer = optim.Adam(spell.parameters(), lr=learning_rate)
    watch_scheduler = optim.lr_scheduler.StepLR(
        watch_optimizer, step_size=learning_rate_decay, gamma=0.1)
    spell_scheduler = optim.lr_scheduler.StepLR(
        spell_optimizer, step_size=learning_rate_decay, gamma=0.1)

    # Index 38 is treated as padding and excluded from the loss
    # (presumably the <pad> index of the character set — verify).
    criterion = nn.CrossEntropyLoss(ignore_index=38)

    train_loader, eval_loader = get_dataloaders(
        data_path, batch_size, videomax, txtmax, worker,
        ratio_of_validation=ratio_of_validation)
    total_batch = len(train_loader)
    total_eval_batch = len(eval_loader)

    for epoch in range(n_iters):
        avg_loss = 0.0
        avg_eval_loss = 0.0

        watch = watch.train()
        spell = spell.train()
        for i, (data, labels) in enumerate(train_loader):
            loss = train(data.to(device), labels.to(device), watch, spell,
                         watch_optimizer, spell_optimizer, criterion, True)
            avg_loss += loss
            print('Batch : ', i + 1, '/', total_batch,
                  ', ERROR in this minibatch: ', loss)
            # Free references promptly to reduce peak GPU memory.
            del data, labels, loss

        watch = watch.eval()
        spell = spell.eval()
        for k, (data, labels) in enumerate(eval_loader):
            loss = train(data.to(device), labels.to(device), watch, spell,
                         watch_optimizer, spell_optimizer, criterion, False)
            avg_eval_loss += loss
            # BUG FIX: the original printed the stale training-loop counters
            # (i + 1 / total_batch) here; report the validation progress.
            print('Batch : ', k + 1, '/', total_eval_batch,
                  ', Validation ERROR in this minibatch: ', loss)
            del data, labels, loss

        # BUG FIX: since PyTorch 1.1 the LR schedulers must be stepped AFTER
        # the epoch's optimizer updates; stepping them at the top of the epoch
        # skipped the initial learning rate and shifted the decay schedule.
        watch_scheduler.step()
        spell_scheduler.step()

        print('epoch:', epoch, ' train_loss:', float(avg_loss / total_batch))
        print('epoch:', epoch, ' eval_loss:',
              float(avg_eval_loss / total_eval_batch))

        if epoch % save_every == 0 and epoch != 0:
            # Saves the full model objects (not just state_dicts).
            torch.save(watch, 'watch{}.pt'.format(epoch))
            torch.save(spell, 'spell{}.pt'.format(epoch))