def __init__(self, src_vocab, tgt_vocab, checkpoint, opts):
    """Restore a trained Transformer from a checkpoint for inference.

    Args:
        src_vocab: source-side vocabulary (supports ``len()`` and ``pad_id``).
        tgt_vocab: target-side vocabulary.
        checkpoint: dict with ``'hparams'`` (training hyperparameters) and
            ``'model'`` (a state dict), as saved by the trainer.
        opts: decoding/runtime options, stored on the instance for later use.
    """
    self.src_vocab = src_vocab
    self.tgt_vocab = tgt_vocab
    hparams = checkpoint['hparams']
    # max_len + 2 leaves room for the BOS/EOS tokens added around sequences.
    transformer = Transformer(len(src_vocab), len(tgt_vocab),
                              hparams.max_len + 2,
                              n_layers=hparams.n_layers,
                              d_model=hparams.d_model,
                              d_emb=hparams.d_model,
                              d_hidden=hparams.d_hidden,
                              n_heads=hparams.n_heads,
                              d_k=hparams.d_k,
                              d_v=hparams.d_v,
                              dropout=hparams.dropout,
                              pad_id=src_vocab.pad_id)
    transformer.load_state_dict(checkpoint['model'])
    # FIX: LogSoftmax without an explicit `dim` is deprecated and its implicit
    # choice depends on input rank. Normalize over the vocabulary (last)
    # dimension, matching the nn.LogSoftmax(dim=-1) used elsewhere in this file.
    log_proj = torch.nn.LogSoftmax(dim=-1)
    if hparams.cuda:
        transformer.cuda()
        log_proj.cuda()
    # Inference only: disable dropout etc.
    transformer.eval()
    self.hparams = hparams
    self.opts = opts
    self.model = transformer
    self.log_proj = log_proj
def create_model(opt):
    """Construct a Transformer sized from the saved dataset dictionaries.

    If a checkpoint already exists at ``opt.model_path``, its parameters and
    training state are restored; otherwise a fresh state is returned.

    Args:
        opt: options namespace; mutated with ``src_vocab_size`` and
            ``tgt_vocab_size`` read from ``opt.data_path``.

    Returns:
        Tuple ``(model, model_state)`` where ``model_state`` carries the
        options plus epoch/step counters.
    """
    dataset = torch.load(opt.data_path)
    opt.src_vocab_size = len(dataset['src_dict'])
    opt.tgt_vocab_size = len(dataset['tgt_dict'])

    print('Creating new model parameters..')
    model = Transformer(opt)

    # Fresh bookkeeping; replaced below when a checkpoint is found on disk.
    model_state = {'opt': opt, 'curr_epochs': 0, 'train_steps': 0}

    if os.path.exists(opt.model_path):
        print('Reloading model parameters..')
        model_state = torch.load(opt.model_path)
        model.load_state_dict(model_state['model_params'])

    # NOTE(review): `use_cuda` is a module-level flag defined outside this view.
    if use_cuda:
        print('Using GPU..')
        model = model.cuda()

    return model, model_state
def __init__(self, opt, use_cuda):
    """Load a trained Transformer checkpoint and prepare it for decoding.

    Args:
        opt: runtime options; ``opt.model_path`` locates the checkpoint file.
        use_cuda: when True, the model is moved to the GPU.
    """
    self.opt = opt
    self.use_cuda = use_cuda
    # Tensor-factory namespace: torch.cuda on GPU, plain torch on CPU.
    self.tt = torch.cuda if use_cuda else torch

    checkpoint = torch.load(opt.model_path)
    model_opt = checkpoint['opt']
    self.model_opt = model_opt

    model = Transformer(model_opt)
    if use_cuda:
        print('Using GPU..')
        model = model.cuda()

    # Log-probability projection over the vocabulary (last) dimension.
    prob_proj = nn.LogSoftmax(dim=-1)
    model.load_state_dict(checkpoint['model_params'])
    print('Loaded pre-trained model_state..')

    self.model = model
    self.model.prob_proj = prob_proj
    # Inference only: freeze dropout/batch-norm behavior.
    self.model.eval()
def main():
    """Entry point: build vocabularies, construct a Transformer, and train it."""
    args = parse_args()

    loader = DataLoader(MachineTranslationDataLoader, args.src, args.tgt,
                        max_vocab_size=args.max_vocab_size,
                        min_word_count=args.min_word_count,
                        max_len=args.max_len,
                        cuda=args.cuda)
    src_vocab = loader.loader.src.vocab
    tgt_vocab = loader.loader.tgt_in.vocab
    print(len(src_vocab), len(tgt_vocab))

    # Persist both vocabularies so inference can rebuild the token mapping.
    torch.save(src_vocab, os.path.join(args.logdir, 'src_vocab.pt'))
    torch.save(tgt_vocab, os.path.join(args.logdir, 'tgt_vocab.pt'))

    # max_len + 2 leaves room for the BOS/EOS markers.
    transformer = Transformer(len(src_vocab), len(tgt_vocab), args.max_len + 2,
                              n_layers=args.n_layers, d_model=args.d_model,
                              d_emb=args.d_model, d_hidden=args.d_hidden,
                              n_heads=args.n_heads, d_k=args.d_k, d_v=args.d_v,
                              dropout=args.dropout, pad_id=src_vocab.pad_id)

    # Zero class weight on the pad id keeps padding out of the loss.
    weights = torch.ones(len(tgt_vocab))
    weights[tgt_vocab.pad_id] = 0
    optimizer = torch.optim.Adam(transformer.get_trainable_parameters(),
                                 lr=args.lr)
    loss_fn = torch.nn.CrossEntropyLoss(weights)

    if args.cuda:
        transformer = transformer.cuda()
        loss_fn = loss_fn.cuda()

    def loss_fn_wrap(src, tgt_in, tgt_out, src_pos, tgt_pos, logits):
        # Trainer-facing loss: flatten gold targets to align with the logits.
        return loss_fn(logits, tgt_out.contiguous().view(-1))

    def get_performance(gold, logits, pad_id):
        # Count correct predictions, ignoring positions that are padding.
        gold = gold.contiguous().view(-1)
        preds = logits.max(dim=1)[1]
        hits = preds.data.eq(gold.data)
        return hits.masked_select(gold.ne(pad_id).data).sum()

    def epoch_fn(epoch, stats):
        # Aggregate per-step counts into one accuracy for the epoch.
        corrects, words = zip(*[(s['n_corrects'], s['n_words']) for s in stats])
        return {'train_acc': sum(corrects) / sum(words)}

    def step_fn(step, src, tgt_in, tgt_out, src_pos, tgt_pos, logits):
        n_corrects = get_performance(tgt_out, logits, tgt_vocab.pad_id)
        n_words = tgt_out.data.ne(tgt_vocab.pad_id).sum()
        return {'n_corrects': n_corrects, 'n_words': n_words}

    trainer = Trainer(transformer, loss_fn_wrap, optimizer,
                      logdir=args.logdir, hparams=args,
                      save_mode=args.save_mode)
    trainer.train(lambda: loader.iter(batch_size=args.batch_size,
                                      with_pos=True),
                  epochs=args.epochs, epoch_fn=epoch_fn,
                  step_fn=step_fn, metric='train_acc')