# load the trained parameters for evaluation
model.load_parameters(param_name, ctx)

# hybridize with static memory allocation for faster repeated inference
static_alloc = True
model.hybridize(static_alloc=static_alloc)
logging.info(model)

# prepare the beam-search translator
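# Note: gluonnlp's BeamSearchScorer implements the GNMT length penalty of
# Wu et al. (2016), scoring a hypothesis as log_prob / lp with
# lp = ((K + length) / (K + 1)) ** alpha, so longer outputs are not unfairly
# penalized by their summed log-probabilities.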
translator = BeamSearchTranslator(model=model, beam_size=args.beam_size,
                                  scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha,
                                                                    K=args.lp_k),
                                  max_length=200)
logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k))

# MaskedSoftmaxCELoss applies softmax cross-entropy only at valid
# (non-padding) target positions, using the sequence valid_length as a mask
test_loss_function = MaskedSoftmaxCELoss()
test_loss_function.hybridize(static_alloc=static_alloc)

def inference():
    """Run beam-search inference on the test dataset."""
    logging.info('Inference on test_dataset!')

    # prepare the data loader
    test_data_loader = dataprocessor.get_dataloader(data_test, args,
                                                    dataset_type='test',
                                                    use_average_length=True)
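    # use_average_length=True is assumed to make the bucket sampler batch by an
    # approximate token budget rather than a fixed number of sentences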

    if args.bleu == 'tweaked':
        bpe = args.dataset not in ('IWSLT2015', 'TOY')
        split_compound_word = bpe
        tokenized = True
    elif args.bleu == '13a' or args.bleu == 'intl':
        # sacrebleu-style tokenization operates on detokenized text
        bpe = False
        split_compound_word = False
        tokenized = False
    else:
        raise NotImplementedError('Unsupported bleu option: {}'.format(args.bleu))
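    # The original example is truncated here. As a sketch, a GluonNLP-style
    # evaluation loop would drive the translator roughly like:
    #   for src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids \
    #           in test_data_loader:
    #       samples, _, sample_valid_length = translator.translate(
    #           src_seq=src_seq.as_in_context(ctx),
    #           src_valid_length=src_valid_length.as_in_context(ctx))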
Example #2
def main(_argv):

    os.makedirs(os.path.join('models', 'captioning', 'experiments',
                             FLAGS.model_id),
                exist_ok=True)

    if FLAGS.num_gpus > 0:  # only supports 1 GPU
        ctx = mx.gpu()
    else:
        ctx = mx.cpu()

    # Set up logging
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = os.path.join('models', 'captioning', 'experiments',
                                 FLAGS.model_id, 'log.txt')
    log_dir = os.path.dirname(log_file_path)
    os.makedirs(log_dir, exist_ok=True)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)

    key_flags = FLAGS.get_key_flags_for_module(sys.argv[0])
    logging.info('\n'.join(f.serialize() for f in key_flags))

    # set up tensorboard summary writer
    tb_sw = SummaryWriter(log_dir=os.path.join(log_dir, 'tb'),
                          comment=FLAGS.model_id)
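    # tb_sw is a tensorboardX-style SummaryWriter; the training loop can log
    # scalars with e.g. tb_sw.add_scalar('loss/train', loss_value, global_step)
    # (illustrative tag and variable names, not from the original script)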

    # are we using features or do we include the CNN?
    if FLAGS.feats_model is None:
        backbone_net = get_model(FLAGS.backbone, pretrained=True,
                                 ctx=ctx).features
        cnn_model = FrameModel(backbone_net, 11)  # number of classes hardcoded to 11
        if FLAGS.backbone_from_id:
            backbone_dir = os.path.join('models', 'vision', 'experiments',
                                        FLAGS.backbone_from_id)
            if not os.path.exists(backbone_dir):
                raise FileNotFoundError(backbone_dir)
            files = [f for f in os.listdir(backbone_dir)
                     if f.endswith('.params')]
            if files:
                model_name = sorted(files, reverse=True)[0]  # latest first
                params_path = os.path.join(backbone_dir, model_name)
                cnn_model.load_parameters(params_path, ctx=ctx)
                logging.info('Loaded backbone params: {}'.format(params_path))

        if FLAGS.freeze_backbone:
            # grad_req='null' disables gradient computation, freezing the weights
            for param in cnn_model.collect_params().values():
                param.grad_req = 'null'

        cnn_model = TimeDistributed(cnn_model.backbone)

        src_embed = cnn_model
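        # TimeDistributed (a wrapper assumed from this repo) runs the 2D CNN on
        # every frame of a clip: conceptually (N, T, C, H, W) is reshaped to
        # (N*T, C, H, W), passed through the backbone, then restored to
        # (N, T, num_features), giving the sequence encoder one vector per frame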

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(FLAGS.data_shape),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=0.4,
                                         contrast=0.4,
                                         saturation=0.4),
            transforms.RandomLighting(0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        transform_test = transforms.Compose([
            transforms.Resize(FLAGS.data_shape + 32),
            transforms.CenterCrop(FLAGS.data_shape),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
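        # the Normalize means/stds above are the standard ImageNet statistics
        # expected by the ImageNet-pretrained backbone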

    else:
        # inputs are pre-extracted features, so src_embed is an identity block
        # (zero-rate Dropout) rather than a source-word Embedding
        from mxnet.gluon import nn
        src_embed = nn.HybridSequential(prefix='src_embed_')
        with src_embed.name_scope():
            src_embed.add(nn.Dropout(rate=0.0))

        transform_train = None
        transform_test = None

    # setup the data
    data_train = TennisSet(split='train',
                           transform=transform_train,
                           captions=True,
                           max_cap_len=FLAGS.tgt_max_len,
                           every=FLAGS.every,
                           feats_model=FLAGS.feats_model)
    data_val = TennisSet(split='val',
                         transform=transform_test,
                         captions=True,
                         vocab=data_train.vocab,
                         every=FLAGS.every,
                         inference=True,
                         feats_model=FLAGS.feats_model)
    data_test = TennisSet(split='test',
                          transform=transform_test,
                          captions=True,
                          vocab=data_train.vocab,
                          every=FLAGS.every,
                          inference=True,
                          feats_model=FLAGS.feats_model)
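    # val/test reuse the training vocabulary so caption token ids stay
    # consistent across splits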

    val_tgt_sentences = data_val.get_captions(split=True)
    test_tgt_sentences = data_test.get_captions(split=True)
    write_sentences(
        val_tgt_sentences,
        os.path.join('models', 'captioning', 'experiments', FLAGS.model_id,
                     'val_gt.txt'))
    write_sentences(
        test_tgt_sentences,
        os.path.join('models', 'captioning', 'experiments', FLAGS.model_id,
                     'test_gt.txt'))

    # load embeddings for tgt_embed
    if FLAGS.emb_file:
        word_embs = nlp.embedding.TokenEmbedding.from_file(
            file_path=os.path.join('data', FLAGS.emb_file))
        data_train.vocab.set_embedding(word_embs)

        # idx_to_vec has shape (vocab_size, emb_dim); copy the pretrained
        # vectors into a trainable target-side Embedding layer
        input_dim, output_dim = data_train.vocab.embedding.idx_to_vec.shape
        tgt_embed = gluon.nn.Embedding(input_dim, output_dim)
        tgt_embed.initialize(ctx=ctx)
        tgt_embed.weight.set_data(data_train.vocab.embedding.idx_to_vec)
    else:
        tgt_embed = None  # NMTModel will create its own target embedding

    # setup the model
    encoder, decoder = get_gnmt_encoder_decoder(
        cell_type=FLAGS.cell_type,
        hidden_size=FLAGS.num_hidden,
        dropout=FLAGS.dropout,
        num_layers=FLAGS.num_layers,
        num_bi_layers=FLAGS.num_bi_layers)
    model = NMTModel(src_vocab=None,
                     tgt_vocab=data_train.vocab,
                     encoder=encoder,
                     decoder=decoder,
                     embed_size=FLAGS.emb_size,
                     prefix='gnmt_',
                     src_embed=src_embed,
                     tgt_embed=tgt_embed)
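    # src_vocab=None: the "source" here is video (frames or pre-extracted
    # features) rather than tokens, so only the target side has a vocabulary
    # and src_embed supplies the source representation directly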

    model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
    static_alloc = True
    model.hybridize(static_alloc=static_alloc)
    logging.info(model)

    start_epoch = 0
    exp_dir = os.path.join('models', 'captioning', 'experiments',
                           FLAGS.model_id)
    if os.path.exists(exp_dir):
        # resume from the latest numbered checkpoint; skip 'valid_best.params'
        files = sorted([f for f in os.listdir(exp_dir)
                        if f.endswith('.params') and f != 'valid_best.params'],
                       reverse=True)  # latest epoch first
        if files:
            model_name = files[0]
            start_epoch = int(model_name.split('.')[0]) + 1
            params_path = os.path.join(exp_dir, model_name)
            model.load_parameters(params_path, ctx=ctx)
            logging.info('Loaded model params: {}'.format(params_path))

    # setup the beam search
    translator = BeamSearchTranslator(model=model,
                                      beam_size=FLAGS.beam_size,
                                      scorer=nlp.model.BeamSearchScorer(
                                          alpha=FLAGS.lp_alpha, K=FLAGS.lp_k),
                                      max_length=FLAGS.tgt_max_len + 100)
    logging.info('Use beam_size={}, alpha={}, K={}'.format(
        FLAGS.beam_size, FLAGS.lp_alpha, FLAGS.lp_k))

    # setup the loss function
    loss_function = MaskedSoftmaxCELoss()
    loss_function.hybridize(static_alloc=static_alloc)

    # run the training
    train(data_train, data_val, data_test, model, loss_function,
          val_tgt_sentences, test_tgt_sentences, translator, start_epoch, ctx,
          tb_sw)