Example 1
import os
import logging

import torch

# The remaining names used below (parse_args, init_logging, PhoenixVideo,
# Vocabulary, MainStream, CtcLoss, Trainer, ModelManager, train, eval,
# eval_tf, eval_dec) are project-local; import them from the repository's
# own modules.


def main():
    opts = parse_args()
    init_logging(os.path.join(opts.log_dir, '{:s}_log.txt'.format(opts.task)))

    if torch.cuda.is_available():
        torch.cuda.set_device(opts.gpu)
        logging.info("Using GPU!")
        device = "cuda"
    else:
        logging.info("Using CPU!")
        device = "cpu"

    logging.info(opts)

    train_datasets = PhoenixVideo(opts.vocab_file,
                                  opts.corpus_dir,
                                  opts.video_path,
                                  phase="train",
                                  DEBUG=opts.DEBUG)
    valid_datasets = PhoenixVideo(opts.vocab_file,
                                  opts.corpus_dir,
                                  opts.video_path,
                                  phase="dev",
                                  DEBUG=opts.DEBUG)
    vocab_size = valid_datasets.vocab.num_words
    blank_id = valid_datasets.vocab.word2index['<BLANK>']
    vocabulary = Vocabulary(opts.vocab_file)
    #model = DilatedSLRNet(opts, device, vocab_size, vocabulary,
    #                      dilated_channels=512, num_blocks=5, dilations=[1, 2, 4], dropout=0.0)
    model = MainStream(vocab_size)
    criterion = CtcLoss(opts, blank_id, device, reduction="none")

    # print(model)

    # Build trainer
    trainer = Trainer(opts, model, criterion, vocabulary, vocab_size, blank_id)

    if os.path.exists(opts.check_point):
        logging.info("Loading checkpoint file from {}".format(
            opts.check_point))
        epoch, num_updates, loss = trainer.load_checkpoint(opts.check_point)
    else:
        logging.info("No checkpoint file in found in {}".format(
            opts.check_point))
        epoch, num_updates, loss = 0, 0, 0.0

    trainer.set_num_updates(num_updates)
    model_manager = ModelManager(max_num_models=5)
    while (epoch < opts.max_epoch
           and trainer.get_num_updates() < opts.max_updates):
        epoch += 1
        trainer.adjust_learning_rate(epoch)
        #trainer.dynamic_freeze_layers(epoch)
        loss = train(opts, train_datasets, valid_datasets, trainer, epoch,
                     num_updates, loss)

        #if num_updates % opts.save_interval_updates == 0:
        # Both branches evaluate twice; the first result is overwritten, so
        # only the second call's error is used to name the checkpoint below.
        if epoch <= opts.stage_epoch * 2:
            phoenix_eval_err = eval(opts, valid_datasets, trainer, epoch)
            phoenix_eval_err = eval_tf(opts, valid_datasets, trainer, epoch)
        else:
            phoenix_eval_err = eval(opts, valid_datasets, trainer, epoch)
            phoenix_eval_err = eval_dec(opts, valid_datasets, trainer, epoch)

        save_ckpt = os.path.join(
            opts.log_dir, 'ep{:d}_{:.4f}.pkl'.format(epoch,
                                                     phoenix_eval_err[0]))
        trainer.save_checkpoint(save_ckpt, epoch, num_updates, loss)
        model_manager.update(save_ckpt, phoenix_eval_err, epoch)
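
ModelManager is not defined in the example above. The sketch below is a hypothetical reconstruction consistent with how it is called (ModelManager(max_num_models=5) and update(save_ckpt, phoenix_eval_err, epoch), where phoenix_eval_err[0] is the validation error embedded in the checkpoint filename): it retains the max_num_models best checkpoints and deletes the rest from disk.

import os


class ModelManager:
    """Hypothetical sketch: retain only the best-N checkpoints on disk."""

    def __init__(self, max_num_models=5):
        self.max_num_models = max_num_models
        self.models = []  # (error, path, epoch), lowest error first

    def update(self, path, err, epoch):
        # err is assumed to be a sequence whose first entry is the
        # validation error (e.g. WER) used to rank checkpoints.
        self.models.append((err[0], path, epoch))
        self.models.sort(key=lambda m: m[0])
        # Drop the worst checkpoints beyond the retention limit.
        while len(self.models) > self.max_num_models:
            _, worst_path, _ = self.models.pop()
            if os.path.exists(worst_path):
                os.remove(worst_path)
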
Example 2
import os
import logging

import torch

# As in Example 1, the project-local helpers (parse_args, setup_seed,
# init_logging, PhoenixVideo, Vocabulary, MainStream, CtcLoss, Trainer,
# ModelManager, train, eval, eval_train) come from the repository's own
# modules.


def main():
    opts = parse_args()
    setup_seed(opts.seed)
    init_logging(
        os.path.join(opts.log_dir,
                     '{:s}_seed{}_log.txt'.format(opts.task, opts.seed)))

    if torch.cuda.is_available():
        torch.cuda.set_device(opts.gpu)
        logging.info("Using GPU!")
        device = "cuda"
    else:
        logging.info("Using CPU!")
        device = "cpu"

    logging.info(opts)

    train_datasets = PhoenixVideo(opts.vocab_file,
                                  opts.corpus_dir,
                                  opts.video_path,
                                  phase="train",
                                  DEBUG=opts.DEBUG)
    valid_datasets = PhoenixVideo(opts.vocab_file,
                                  opts.corpus_dir,
                                  opts.video_path,
                                  phase="dev",
                                  DEBUG=opts.DEBUG)
    vocab_size = valid_datasets.vocab.num_words
    blank_id = valid_datasets.vocab.word2index['<BLANK>']
    vocabulary = Vocabulary(opts.vocab_file)
    model = MainStream(vocab_size, opts.bn_momentum)
    criterion = CtcLoss(opts, blank_id, device, reduction="none")

    logging.info(model)
    # Build trainer
    trainer = Trainer(opts, model, criterion, vocabulary, vocab_size, blank_id)

    if os.path.exists(opts.check_point):
        logging.info("Loading checkpoint file from {}".format(
            opts.check_point))
        epoch, num_updates, loss = trainer.load_checkpoint(opts.check_point)
    elif os.path.exists(opts.pretrain):
        logging.info("Loading checkpoint file from {}".format(opts.pretrain))
        trainer.pretrain(opts)
        epoch, num_updates, loss = 0, 0, 0.0
    else:
        logging.info("No checkpoint file in found in {}".format(
            opts.check_point))
        epoch, num_updates, loss = 0, 0, 0.0

    logging.info('| num. module params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    trainer.set_num_updates(num_updates)
    model_manager = ModelManager(max_num_models=25)
    while (epoch < opts.max_epoch
           and trainer.get_num_updates() < opts.max_updates):
        epoch += 1
        trainer.adjust_learning_rate(epoch)
        loss = train(opts, train_datasets, valid_datasets, trainer, epoch,
                     num_updates, loss)

        if epoch <= opts.stage_epoch:
            eval_train(opts, train_datasets, trainer, epoch)
            # phoenix_eval_err = eval_tf(opts, valid_datasets, trainer, epoch)
            phoenix_eval_err = eval(opts, valid_datasets, trainer, epoch)
        else:
            # eval_train(opts, train_datasets, trainer, epoch)
            phoenix_eval_err = eval(opts, valid_datasets, trainer, epoch)

        save_ckpt = os.path.join(
            opts.log_dir, 'ep{:d}_{:.4f}.pkl'.format(epoch,
                                                     phoenix_eval_err[0]))
        trainer.save_checkpoint(save_ckpt, epoch, num_updates, loss)
        model_manager.update(save_ckpt, phoenix_eval_err, epoch)
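
Example 2 calls setup_seed(opts.seed) before any other initialization, but the helper itself is not shown. A common implementation (an assumption, not verified repository code) seeds every RNG the training loop may touch and forces deterministic cuDNN kernels:

import random

import numpy as np
import torch


def setup_seed(seed):
    # Hypothetical sketch of the seeding helper used in Example 2.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels trade some speed for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
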