def run_validation_bleu_score(model, SRC, TGT, valid_iter):
    translate = []
    tgt = []
    for batch in valid_iter:
        src = batch.src.transpose(0, 1)[:1].cuda()
        src_mask = (src != SRC.vocab.stoi[BLANK_WORD]).unsqueeze(-2)
        out = greedy_decode(model, src, src_mask, max_len=100,
                            start_symbol=TGT.vocab.stoi[BOS_WORD])
        for k in range(out.size(0)):
            # Collect decoded tokens up to the first end-of-sentence symbol
            translate_str = []
            for i in range(1, out.size(1)):
                sym = TGT.vocab.itos[out[k, i]]
                if sym == EOS_WORD:
                    break
                translate_str.append(sym)
            # Collect the reference tokens the same way
            tgt_str = []
            for j in range(1, batch.trg.size(0)):
                sym = TGT.vocab.itos[batch.trg.data[j, k]]
                if sym == EOS_WORD:
                    break
                tgt_str.append(sym)
            translate.append(translate_str)
            tgt.append(tgt_str)

    # Join tokens into plain strings -- essential for sacrebleu calculations
    translation_sentences = [" ".join(x) for x in translate]
    target_sentences = [" ".join(x) for x in tgt]
    bleu_validation = evaluate_bleu(translation_sentences, target_sentences)
    print('Validation BLEU Score', bleu_validation)
    return bleu_validation
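# `evaluate_bleu` is not defined in this file. A minimal sketch of what it is
# assumed to do, built on sacrebleu's corpus_bleu API (one detokenized
# hypothesis string per sentence, plus one aligned reference stream):
import sacrebleu

def evaluate_bleu(hypotheses, references):
    # corpus_bleu takes the hypothesis list and a list of reference streams;
    # .score is the corpus-level BLEU as a float in [0, 100].
    return sacrebleu.corpus_bleu(hypotheses, [references]).score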
def __call__(self, x, y, norm):
    x = self.generator(x)
    loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                          y.contiguous().view(-1)) / norm
    loss.backward()
    if self.opt is not None:
        self.opt.step()
        self.opt.optimizer.zero_grad()
    return loss.item() * norm


if __name__ == '__main__':
    # Train a small model on a toy copy task, then greedy-decode one sample
    V = 11
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
    model = make_model(V, V, n=2)
    model = model.cuda()
    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
                        torch.optim.Adam(model.parameters(), lr=0,
                                         betas=(0.9, 0.98), eps=1e-9))

    for epoch in range(10):
        model.train()
        run_epoch(data_gen(V, 30, 20), model,
                  SimpleLossCompute(model.generator, criterion, model_opt))
        model.eval()
        print(run_epoch(data_gen(V, 30, 5), model,
                        SimpleLossCompute(model.generator, criterion, None)))

    model.eval()
    src = Variable(torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])).cuda()
    src_mask = Variable(torch.ones(1, 1, 10)).cuda()
    print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))
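# `greedy_decode` is used throughout but defined elsewhere. This sketch follows
# the Annotated Transformer version this code appears to be based on: encode
# the source once, then repeatedly decode and append the argmax token. It
# assumes the model exposes encode()/decode()/generator and that the usual
# subsequent_mask() causal-mask helper is in scope.
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for _ in range(max_len - 1):
        out = model.decode(memory, src_mask, Variable(ys),
                           Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        # Append the most likely next token and feed it back in
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data)
                        .fill_(next_word.item())], dim=1)
    return ys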
        model_par.eval()
        loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter), model_par,
                         MultiGPULossCompute(model.generator, criterion,
                                             devices=devices, opt=None))
        print(loss)
    else:
        model = torch.load('iwslt.pt')

    for i, batch in enumerate(valid_iter):
        src = batch.src.transpose(0, 1)[:1]
        src_mask = (src != SRC.vocab.stoi[BLANK_WORD]).unsqueeze(-2)
        out = greedy_decode(model, src, src_mask, max_len=60,
                            start_symbol=TGT.vocab.stoi[BOS_WORD])
        print('Translation:', end='\t')
        for j in range(1, out.size(1)):
            sym = TGT.vocab.itos[out[0, j]]
            if sym == EOS_WORD:
                break
            print(sym, end=' ')
        print()
        print('Target:', end='\t')
        for j in range(batch.trg.size(0)):
            # itos is a list, so it must be indexed, not called
            sym = TGT.vocab.itos[batch.trg.data[j, 0]]
            if sym == EOS_WORD:
                break
            print(sym, end=' ')
        print()
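# `rebatch` (fed to run_epoch above) is defined elsewhere; it adapts torchtext's
# (seq_len, batch) layout to the (batch, seq_len) Batch object the model
# expects. A sketch in the style of the Annotated Transformer, assuming a
# Batch class that builds the source/target masks from the padding index:
def rebatch(pad_idx, batch):
    # Transpose so the batch dimension comes first, then wrap
    src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
    return Batch(src, trg, pad_idx)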
def test(args):
    # TODO: Add testing configurations
    SRC = data.Field(tokenize=tokenize_de, pad_token=BLANK_WORD, lower=args.lower)
    TGT = data.Field(tokenize=tokenize_en, init_token=BOS_WORD, eos_token=EOS_WORD,
                     pad_token=BLANK_WORD, lower=args.lower)

    # Load IWSLT data ---> German-to-English translation
    if args.dataset == 'IWSLT':
        train, val, test = datasets.IWSLT.splits(
            exts=('.de', '.en'), fields=(SRC, TGT),
            filter_pred=lambda x: len(vars(x)['src']) <= args.max_length and
                                  len(vars(x)['trg']) <= args.max_length)
    else:
        train, val, test = datasets.Multi30k.splits(
            exts=('.de', '.en'), fields=(SRC, TGT),
            filter_pred=lambda x: len(vars(x)['src']) <= args.max_length and
                                  len(vars(x)['trg']) <= args.max_length)

    # Build vocabularies from words above the frequency threshold
    SRC.build_vocab(train.src, min_freq=args.min_freq)
    TGT.build_vocab(train.trg, min_freq=args.min_freq)

    print('Running test...')
    print("Size of source vocabulary:", len(SRC.vocab))
    print("Size of target vocabulary:", len(TGT.vocab))

    model = make_model(len(SRC.vocab), len(TGT.vocab), n=args.num_blocks,
                       d_model=args.hidden_dim, d_ff=args.ff_dim,
                       h=args.num_heads, dropout=args.dropout)
    print("Model made with n:", args.num_blocks, "hidden_dim:", args.hidden_dim,
          "feed forward dim:", args.ff_dim, "heads:", args.num_heads,
          "dropout:", args.dropout)

    if args.load_model:
        print("Loading model from [%s]" % args.load_model)
        model.load_state_dict(torch.load(args.load_model))

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("Number of parameters: ", params)

    # Per-module parameter breakdown (kept for reference):
    # feed_forward, attn, embed, sublayer, generator = [], [], [], [], []
    # for name, param in model.named_parameters():
    #     if "feed_forward" in name:
    #         feed_forward.append(np.prod(param.size()))
    #     if "attn" in name:
    #         attn.append(np.prod(param.size()))
    #     if "embed" in name:
    #         embed.append(np.prod(param.size()))
    #     if "sublayer" in name:
    #         sublayer.append(np.prod(param.size()))
    #     if "generator" in name:
    #         generator.append(np.prod(param.size()))
    # print("Num parameters in original attn layer", np.sum(attn))
    # print("Num parameters in original embedding layer", np.sum(embed))
    # print("Num parameters in original sublayer", np.sum(sublayer))
    # print("Num parameters in original generator layer", np.sum(generator))

    # Count parameters in the embedding and generator layers
    embed_and_generator = []
    for name, param in model.named_parameters():
        if "embed" in name or "generator" in name:
            embed_and_generator.append(np.prod(param.size()))
    print("Num parameters:", np.sum(embed_and_generator))

    # pad_idx = TGT.vocab.stoi[BLANK_WORD]
    # criterion = LabelSmoothing(size=len(TGT.vocab), padding_idx=pad_idx,
    #                            smoothing=0.1)

    # UNCOMMENT WHEN RUNNING ON RESEARCH MACHINES - run on GPU
    model.cuda()

    test_iter = MyIterator(test, batch_size=args.batch_size, device=0,
                           repeat=False,
                           sort_key=lambda x: (len(x.src), len(x.trg)),
                           batch_size_fn=batch_size_fn, train=False)

    # ## Post-Linear Quantization Code
    # overrides_yaml = """
    # encoder.layers.*.self_attn.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # encoder.layers.*.feed_forward.*:
    #     bits_activations: 8
    #     bits_weights: 8
    #     bits_bias: 8
    # encoder.layers.*.sublayer.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # encoder.norm.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # decoder.layers.*.self_attn.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # decoder.layers.*.feed_forward.*:
    #     bits_activations: 8
    #     bits_weights: 8
    #     bits_bias: 8
    # decoder.layers.*.src_attn.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # decoder.layers.*.sublayer.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # decoder.norm.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # src_embed.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # tgt_embed.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # generator.*:
    #     bits_activations: null
    #     bits_weights: null
    #     bits_bias: null
    # """

    # CREATE STATS FILE (activation statistics needed for quantization)
    # distiller.utils.assign_layer_fq_names(model)
    # stats_file = './acts_quantization_stats.yaml'
    # if not os.path.isfile(stats_file):
    #     def eval_for_stats(model):
    #         valid_iter = MyIterator(val, batch_size=args.batch_size, device=0,
    #                                 repeat=False,
    #                                 sort_key=lambda x: (len(x.src), len(x.trg)),
    #                                 batch_size_fn=batch_size_fn, train=False,
    #                                 sort=False)
    #         model.eval()
    #         run_epoch((rebatch(pad_idx, b) for b in valid_iter), model,
    #                   MultiGPULossCompute(model.generator, criterion,
    #                                       devices=devices, opt=None),
    #                   args, SRC, TGT, valid_iter, is_valid=True)
    #     collect_quant_stats(distiller.utils.make_non_parallel_copy(model),
    #                         eval_for_stats, save_dir='.')

    # Post-Linear Quantization block
    # overrides = distiller.utils.yaml_ordered_load(overrides_yaml)
    # quantizer = PostTrainLinearQuantizer(deepcopy(model),
    #                                      mode="ASYMMETRIC_UNSIGNED",
    #                                      overrides=overrides)
    # dummy_input = (torch.ones(130, 10).to(dtype=torch.long),
    #                torch.ones(130, 22).to(dtype=torch.long),
    #                torch.ones(130, 1, 10).to(dtype=torch.long),
    #                torch.ones(130, 22, 22).to(dtype=torch.long))
    # quantizer.prepare_model(dummy_input)
    # model = quantizer.model

    model.eval()
    print(model)

    translate = []
    tgt = []
    start_infer_time = time.time()
    for k, batch in enumerate(test_iter):
        src_orig = batch.src.transpose(0, 1).cuda()
        trg_orig = batch.trg.transpose(0, 1)
        # Decode one sentence at a time
        for m in range(0, len(src_orig), 1):
            src = src_orig[m:(m + 1)].cuda()
            trg = trg_orig[m:(m + 1)]
            src_mask = (src != SRC.vocab.stoi["<blank>"]).unsqueeze(-2)
            out = greedy_decode(model, src, src_mask, max_len=100,
                                start_symbol=TGT.vocab.stoi["<s>"])
            translate_str = []
            for i in range(0, out.size(0)):
                for j in range(1, out.size(1)):
                    sym = TGT.vocab.itos[out[i, j]]
                    if sym == "</s>":
                        break
                    translate_str.append(sym)
            tgt_str = []
            for i in range(trg.size(0)):
                for j in range(1, trg.size(1)):
                    sym = TGT.vocab.itos[trg[i, j]]
                    if sym == "</s>":
                        break
                    tgt_str.append(sym)
            translate.append(translate_str)
            tgt.append(tgt_str)
    print("Time for inference: ", time.time() - start_infer_time)

    # Join tokens into plain strings -- essential for sacrebleu calculations
    translation_sentences = [" ".join(x) for x in translate]
    target_sentences = [" ".join(x) for x in tgt]
    bleu_validation = evaluate_bleu(translation_sentences, target_sentences)
    print('Test BLEU Score:', bleu_validation)
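# `batch_size_fn` (passed to MyIterator above) is not shown in this file. A
# sketch of the Annotated Transformer version this appears to use, which sizes
# batches by padded token count rather than by sentence count:
max_src_in_batch, max_tgt_in_batch = 0, 0

def batch_size_fn(new, count, sofar):
    # Track the widest source/target seen so far in the current batch and
    # return the padded token count if this example were added.
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    return max(count * max_src_in_batch, count * max_tgt_in_batch)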
def train(train_path, val_path, save_path, n_layers=6, model_dim=512,
          feedforward_dim=2048, n_heads=8, dropout_rate=0.1, n_epochs=10,
          max_len=60, min_freq=10, max_val_outputs=20):
    train, val, TGT, SRC, EOS_WORD, BOS_WORD, BLANK_WORD = get_dataset(
        train_path, val_path, min_freq)

    # torch.save(SRC.vocab, 'models/electronics/src_vocab.pt')
    # torch.save(TGT.vocab, 'models/electronics/trg_vocab.pt')
    SRC.vocab = torch.load('models/electronics/src_vocab.pt')
    TGT.vocab = torch.load('models/electronics/trg_vocab.pt')

    pad_idx = TGT.vocab.stoi[BLANK_WORD]

    # model = make_model(len(SRC.vocab), len(TGT.vocab), n=n_layers,
    #                    d_model=model_dim, d_ff=feedforward_dim, h=n_heads,
    #                    dropout=dropout_rate)
    # Resume from the epoch-3 checkpoint instead of training from scratch
    model = torch.load('models/electronics/electronics_autoencoder_epoch3.pt')
    model.cuda()

    criterion = LabelSmoothing(size=len(TGT.vocab), padding_idx=pad_idx,
                               smoothing=0.1)
    criterion.cuda()

    BATCH_SIZE = 2048  # Was 12000, but I only have 12 GB RAM on my single GPU.
    train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=0,
                            repeat=False,  # Faster, with a device warning
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn, train=True)
    valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=0, repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn, train=False)

    model_par = nn.DataParallel(model, device_ids=devices)
    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,
                        torch.optim.Adam(model.parameters(), lr=0,
                                         betas=(0.9, 0.98), eps=1e-9))

    for epoch in range(n_epochs):
        model_par.train()
        run_epoch((rebatch(pad_idx, b) for b in train_iter), model_par,
                  MultiGPULossCompute(model.generator, criterion,
                                      devices=devices, opt=model_opt))
        # Offset by 4 because training resumed from the epoch-3 checkpoint
        save_name = save_path + '_epoch' + str(epoch + 4) + '.pt'
        torch.save(model, save_name)

        model_par.eval()
        loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter), model_par,
                         MultiGPULossCompute(model.generator, criterion,
                                             devices=devices, opt=None))
        print(loss)

        for i, batch in enumerate(valid_iter):
            if i > max_val_outputs:
                break
            src = batch.src.transpose(0, 1)[:1].cuda()
            src_mask = (src != SRC.vocab.stoi[BLANK_WORD]).unsqueeze(-2).cuda()
            out = greedy_decode(model, src, src_mask, max_len=max_len,
                                start_symbol=TGT.vocab.stoi[BOS_WORD])
            print('Translation:', end='\t')
            for k in range(1, out.size(1)):
                sym = TGT.vocab.itos[out[0, k]]
                if sym == EOS_WORD:
                    break
                print(sym, end=' ')
            print()
            print('Target:', end='\t')
            for j in range(batch.trg.size(0)):
                sym = TGT.vocab.itos[batch.trg.data[j, 0]]
                if sym == EOS_WORD:
                    break
                print(sym, end=' ')
            print()
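# `NoamOpt` (used above with warmup=2000) is the Annotated Transformer's
# learning-rate wrapper. A sketch for reference, assuming the usual schedule
# rate = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5):
class NoamOpt:
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        # Update the learning rate for every param group, then step
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))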
                     model_prl,
                     MultiGPULossCompute(model.generator, criterion,
                                         devices=args.devices, opt=None))
    torch.save(model.state_dict(), 'models/params.pkl')

    for batch_index, batch_data in enumerate(val_iter):
        src = batch_data.src.transpose(0, 1)[:1]
        src_mask = (src != src_pad_idx).unsqueeze(-2).to(args.device)
        out = greedy_decode(model, src, src_mask, max_len=60,
                            start_symbol=TGT.vocab.stoi[args.BOS_TOKEN])
        # src is (1, seq_len), so iterate over the first row's token ids
        sent = [SRC.vocab.itos[i] for i in src[0]]
        print("Translation:", end="\t")
        trans = "<s> "  # reset per sentence so outputs do not accumulate
        for i in range(1, out.size(1)):
            tok = TGT.vocab.itos[out[0, i]]
            if tok == args.EOS:
                break
            trans += tok + " "
        print(trans)
        print("\nTarget:", end="\t")
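# Usage note on restoring the checkpoint saved above: torch.save stored only
# the state dict, so the model must be rebuilt before loading. A sketch,
# assuming the same make_model hyperparameters used at training time:
def load_trained_model(src_vocab_size, tgt_vocab_size,
                       path='models/params.pkl'):
    model = make_model(src_vocab_size, tgt_vocab_size)
    model.load_state_dict(torch.load(path))
    model.eval()  # inference mode: disables dropout
    return model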