model.load_parameters(param_name, ctx) static_alloc = True model.hybridize(static_alloc=static_alloc) logging.info(model) # translator prepare translator = BeamSearchTranslator(model=model, beam_size=args.beam_size, scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha, K=args.lp_k), max_length=200) logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k)) test_loss_function = MaskedSoftmaxCELoss() test_loss_function.hybridize(static_alloc=static_alloc) def inference(): """inference function.""" logging.info('Inference on test_dataset!') # data prepare test_data_loader = dataprocessor.get_dataloader(data_test, args, dataset_type='test', use_average_length=True) if args.bleu == 'tweaked': bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY') split_compound_word = bpe tokenized = True elif args.bleu == '13a' or args.bleu == 'intl':
def _latest_epoch_checkpoint(exp_dir):
    """Find the newest per-epoch checkpoint in ``exp_dir``.

    A per-epoch checkpoint is a ``.params`` file whose stem is a decimal epoch
    number (``12.params``, ``0012.params``, ...). Any other ``.params`` file,
    e.g. ``valid_best.params``, is ignored.

    Epochs are compared numerically rather than by file name, so ``10.params``
    correctly beats ``9.params`` even when names are not zero-padded.

    Args:
        exp_dir: directory to scan.

    Returns:
        ``(file_name, epoch)`` for the highest epoch found, or ``None`` if the
        directory does not exist or holds no per-epoch checkpoint.
    """
    if not os.path.isdir(exp_dir):
        return None
    suffix = '.params'
    candidates = []
    for file_name in os.listdir(exp_dir):
        if not file_name.endswith(suffix):
            continue
        stem = file_name[:-len(suffix)]
        # isdecimal() accepts exactly the digit characters int() accepts, so
        # the conversion below cannot raise.
        if stem.isdecimal():
            candidates.append((int(stem), file_name))
    if not candidates:
        return None
    epoch, file_name = max(candidates)
    return file_name, epoch


def main(_argv):
    """Build the captioning model and its data, then run training.

    Everything experiment-specific lives in
    ``models/captioning/experiments/<FLAGS.model_id>``: the log file, the
    TensorBoard writer directory, the ground-truth caption files and the
    ``.params`` checkpoints. If that directory already holds a per-epoch
    checkpoint, the model is resumed from the highest epoch and training
    restarts at the following epoch.

    The source side is either a CNN backbone applied to raw frames (when
    ``FLAGS.feats_model`` is None) or a pass-through block when pre-extracted
    features are used.

    Args:
        _argv: unused; presumably the leftover command-line arguments passed
            in by ``absl.app.run`` -- TODO confirm.

    Raises:
        FileNotFoundError: if ``FLAGS.backbone_from_id`` is set but
            ``models/vision/experiments/<backbone_from_id>`` does not exist.
    """
    exp_dir = os.path.join('models', 'captioning', 'experiments', FLAGS.model_id)
    os.makedirs(exp_dir, exist_ok=True)

    if FLAGS.num_gpus > 0:  # only supports 1 GPU
        ctx = mx.gpu()
    else:
        ctx = mx.cpu()

    # Set up logging: root logger at INFO, plus a log file in the experiment dir
    # (the directory was created above, so no extra existence check is needed).
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = os.path.join(exp_dir, 'log.txt')
    logger.addHandler(logging.FileHandler(log_file_path))

    # Record the flags declared by the module being run.
    key_flags = FLAGS.get_key_flags_for_module(sys.argv[0])
    logging.info('\n'.join(f.serialize() for f in key_flags))

    # set up tensorboard summary writer
    tb_sw = SummaryWriter(log_dir=os.path.join(exp_dir, 'tb'), comment=FLAGS.model_id)

    # are we using features or do we include the CNN?
    if FLAGS.feats_model is None:
        backbone_net = get_model(FLAGS.backbone, pretrained=True, ctx=ctx).features
        cnn_model = FrameModel(backbone_net, 11)  # hardcoded the number of classes
        if FLAGS.backbone_from_id:
            backbone_dir = os.path.join('models', 'vision', 'experiments', FLAGS.backbone_from_id)
            if not os.path.exists(backbone_dir):
                raise FileNotFoundError(backbone_dir)
            # NOTE(review): this is a string sort, so 'valid_best.params' (if
            # present) wins over numbered files, and numbered files only sort by
            # epoch when zero-padded -- confirm this is the intended backbone pick.
            files = sorted((f for f in os.listdir(backbone_dir) if f.endswith('.params')),
                           reverse=True)
            if files:
                backbone_path = os.path.join(backbone_dir, files[0])
                cnn_model.load_parameters(backbone_path, ctx=ctx)
                logging.info('Loaded backbone params: {}'.format(backbone_path))
        if FLAGS.freeze_backbone:
            for param in cnn_model.collect_params().values():
                param.grad_req = 'null'
        # Only the backbone is kept; it is wrapped so it is applied per frame.
        cnn_model = TimeDistributed(cnn_model.backbone)
        src_embed = cnn_model

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(FLAGS.data_shape),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomLighting(0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        transform_test = transforms.Compose([
            transforms.Resize(FLAGS.data_shape + 32),
            transforms.CenterCrop(FLAGS.data_shape),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
    else:
        from mxnet.gluon import nn  # need to do this to force no use of Embedding on src
        src_embed = nn.HybridSequential(prefix='src_embed_')
        with src_embed.name_scope():
            src_embed.add(nn.Dropout(rate=0.0))

        transform_train = None
        transform_test = None

    # setup the data
    data_train = TennisSet(split='train', transform=transform_train, captions=True,
                           max_cap_len=FLAGS.tgt_max_len, every=FLAGS.every,
                           feats_model=FLAGS.feats_model)
    data_val = TennisSet(split='val', transform=transform_test, captions=True,
                         vocab=data_train.vocab, every=FLAGS.every, inference=True,
                         feats_model=FLAGS.feats_model)
    data_test = TennisSet(split='test', transform=transform_test, captions=True,
                          vocab=data_train.vocab, every=FLAGS.every, inference=True,
                          feats_model=FLAGS.feats_model)

    val_tgt_sentences = data_val.get_captions(split=True)
    test_tgt_sentences = data_test.get_captions(split=True)
    write_sentences(val_tgt_sentences, os.path.join(exp_dir, 'val_gt.txt'))
    write_sentences(test_tgt_sentences, os.path.join(exp_dir, 'test_gt.txt'))

    # load embeddings for tgt_embed
    if FLAGS.emb_file:
        word_embs = nlp.embedding.TokenEmbedding.from_file(
            file_path=os.path.join('data', FLAGS.emb_file))
        data_train.vocab.set_embedding(word_embs)

        input_dim, output_dim = data_train.vocab.embedding.idx_to_vec.shape
        tgt_embed = gluon.nn.Embedding(input_dim, output_dim)
        tgt_embed.initialize(ctx=ctx)
        tgt_embed.weight.set_data(data_train.vocab.embedding.idx_to_vec)
    else:
        tgt_embed = None

    # setup the model
    encoder, decoder = get_gnmt_encoder_decoder(cell_type=FLAGS.cell_type,
                                                hidden_size=FLAGS.num_hidden,
                                                dropout=FLAGS.dropout,
                                                num_layers=FLAGS.num_layers,
                                                num_bi_layers=FLAGS.num_bi_layers)
    model = NMTModel(src_vocab=None, tgt_vocab=data_train.vocab, encoder=encoder,
                     decoder=decoder, embed_size=FLAGS.emb_size, prefix='gnmt_',
                     src_embed=src_embed, tgt_embed=tgt_embed)

    model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
    static_alloc = True
    model.hybridize(static_alloc=static_alloc)
    logging.info(model)

    # Resume from the highest-numbered epoch checkpoint, if any. Non-epoch files
    # such as 'valid_best.params' are skipped by the helper.
    start_epoch = 0
    latest = _latest_epoch_checkpoint(exp_dir)
    if latest is not None:
        model_name, last_epoch = latest
        start_epoch = last_epoch + 1
        model_path = os.path.join(exp_dir, model_name)
        model.load_parameters(model_path, ctx=ctx)
        logging.info('Loaded model params: {}'.format(model_path))

    # setup the beam search
    translator = BeamSearchTranslator(model=model, beam_size=FLAGS.beam_size,
                                      scorer=nlp.model.BeamSearchScorer(alpha=FLAGS.lp_alpha,
                                                                        K=FLAGS.lp_k),
                                      max_length=FLAGS.tgt_max_len + 100)
    logging.info('Use beam_size={}, alpha={}, K={}'.format(FLAGS.beam_size, FLAGS.lp_alpha,
                                                           FLAGS.lp_k))

    # setup the loss function
    loss_function = MaskedSoftmaxCELoss()
    loss_function.hybridize(static_alloc=static_alloc)

    # run the training
    train(data_train, data_val, data_test, model, loss_function, val_tgt_sentences,
          test_tgt_sentences, translator, start_epoch, ctx, tb_sw)