import os
import logging

import torch

# Project-specific helpers (parse_args, init_logging, setup_seed, PhoenixVideo,
# Vocabulary, MainStream, CtcLoss, Trainer, ModelManager, train and the eval_*
# routines) are assumed to be imported or defined elsewhere in this module.


def main():
    opts = parse_args()
    init_logging(os.path.join(opts.log_dir, '{:s}_log.txt'.format(opts.task)))

    if torch.cuda.is_available():
        torch.cuda.set_device(opts.gpu)
        logging.info("Using GPU!")
        device = "cuda"
    else:
        logging.info("Using CPU!")
        device = "cpu"
    logging.info(opts)

    # Datasets and vocabulary.
    train_datasets = PhoenixVideo(opts.vocab_file, opts.corpus_dir, opts.video_path,
                                  phase="train", DEBUG=opts.DEBUG)
    valid_datasets = PhoenixVideo(opts.vocab_file, opts.corpus_dir, opts.video_path,
                                  phase="dev", DEBUG=opts.DEBUG)
    vocab_size = valid_datasets.vocab.num_words
    blank_id = valid_datasets.vocab.word2index['<BLANK>']
    vocabulary = Vocabulary(opts.vocab_file)

    # Model and CTC criterion.
    # model = DilatedSLRNet(opts, device, vocab_size, vocabulary,
    #                       dilated_channels=512, num_blocks=5, dilations=[1, 2, 4], dropout=0.0)
    model = MainStream(vocab_size)
    criterion = CtcLoss(opts, blank_id, device, reduction="none")
    # print(model)

    # Build trainer
    trainer = Trainer(opts, model, criterion, vocabulary, vocab_size, blank_id)
    if os.path.exists(opts.check_point):
        logging.info("Loading checkpoint file from {}".format(opts.check_point))
        epoch, num_updates, loss = trainer.load_checkpoint(opts.check_point)
    else:
        logging.info("No checkpoint file found in {}".format(opts.check_point))
        epoch, num_updates, loss = 0, 0, 0.0
    trainer.set_num_updates(num_updates)
    model_manager = ModelManager(max_num_models=5)

    # Training loop: one epoch of training, then evaluation and checkpointing.
    while epoch < opts.max_epoch and trainer.get_num_updates() < opts.max_updates:
        epoch += 1
        trainer.adjust_learning_rate(epoch)
        # trainer.dynamic_freeze_layers(epoch)
        loss = train(opts, train_datasets, valid_datasets, trainer, epoch, num_updates, loss)
        # if num_updates % opts.save_interval_updates == 0:
        if epoch <= opts.stage_epoch * 2:
            phoenix_eval_err = eval(opts, valid_datasets, trainer, epoch)
            phoenix_eval_err = eval_tf(opts, valid_datasets, trainer, epoch)
        else:
            phoenix_eval_err = eval(opts, valid_datasets, trainer, epoch)
            phoenix_eval_err = eval_dec(opts, valid_datasets, trainer, epoch)
        save_ckpt = os.path.join(opts.log_dir,
                                 'ep{:d}_{:.4f}.pkl'.format(epoch, phoenix_eval_err[0]))
        trainer.save_checkpoint(save_ckpt, epoch, num_updates, loss)
        model_manager.update(save_ckpt, phoenix_eval_err, epoch)
def main():
    opts = parse_args()
    setup_seed(opts.seed)
    init_logging(os.path.join(opts.log_dir,
                              '{:s}_seed{}_log.txt'.format(opts.task, opts.seed)))

    if torch.cuda.is_available():
        torch.cuda.set_device(opts.gpu)
        logging.info("Using GPU!")
        device = "cuda"
    else:
        logging.info("Using CPU!")
        device = "cpu"
    logging.info(opts)

    # Datasets and vocabulary.
    train_datasets = PhoenixVideo(opts.vocab_file, opts.corpus_dir, opts.video_path,
                                  phase="train", DEBUG=opts.DEBUG)
    valid_datasets = PhoenixVideo(opts.vocab_file, opts.corpus_dir, opts.video_path,
                                  phase="dev", DEBUG=opts.DEBUG)
    vocab_size = valid_datasets.vocab.num_words
    blank_id = valid_datasets.vocab.word2index['<BLANK>']
    vocabulary = Vocabulary(opts.vocab_file)

    # Model and CTC criterion.
    model = MainStream(vocab_size, opts.bn_momentum)
    criterion = CtcLoss(opts, blank_id, device, reduction="none")
    logging.info(model)

    # Build trainer
    trainer = Trainer(opts, model, criterion, vocabulary, vocab_size, blank_id)
    if os.path.exists(opts.check_point):
        logging.info("Loading checkpoint file from {}".format(opts.check_point))
        epoch, num_updates, loss = trainer.load_checkpoint(opts.check_point)
    elif os.path.exists(opts.pretrain):
        logging.info("Loading checkpoint file from {}".format(opts.pretrain))
        trainer.pretrain(opts)
        epoch, num_updates, loss = 0, 0, 0.0
    else:
        logging.info("No checkpoint file found in {}".format(opts.check_point))
        epoch, num_updates, loss = 0, 0, 0.0

    logging.info('| num. module params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))
    trainer.set_num_updates(num_updates)
    model_manager = ModelManager(max_num_models=25)

    # Training loop: one epoch of training, then evaluation and checkpointing.
    while epoch < opts.max_epoch and trainer.get_num_updates() < opts.max_updates:
        epoch += 1
        trainer.adjust_learning_rate(epoch)
        loss = train(opts, train_datasets, valid_datasets, trainer, epoch, num_updates, loss)
        if epoch <= opts.stage_epoch:
            eval_train(opts, train_datasets, trainer, epoch)
            # phoenix_eval_err = eval_tf(opts, valid_datasets, trainer, epoch)
            phoenix_eval_err = eval(opts, valid_datasets, trainer, epoch)
        else:
            # eval_train(opts, train_datasets, trainer, epoch)
            phoenix_eval_err = eval(opts, valid_datasets, trainer, epoch)
        save_ckpt = os.path.join(opts.log_dir,
                                 'ep{:d}_{:.4f}.pkl'.format(epoch, phoenix_eval_err[0]))
        trainer.save_checkpoint(save_ckpt, epoch, num_updates, loss)
        model_manager.update(save_ckpt, phoenix_eval_err, epoch)
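# --- Illustrative entry point and seeding (assumptions, not the repo's exact code) ---
# A minimal sketch of how this script is typically wired up: a __main__ guard
# that calls main(), plus one possible body for the setup_seed() helper used
# above. The repo's own setup_seed may differ; this version only fixes the
# Python, NumPy, and PyTorch RNG seeds and makes cuDNN deterministic, which is
# the usual pattern for reproducible training runs.
import random

import numpy as np


def setup_seed(seed):
    # Assumption: seed every RNG the training pipeline touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    main()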