def load_model_and_field(device='cpu', save_path='saved', embedding_dim=512,
                         nhead=2, max_seq_len=80, max_pondering_time=10,
                         dropout=0.5):
    """Load the source/target Field objects and a trained UniversalTransformer.

    Parameters
    ----------
    device : str
        Only ``'cpu'`` is supported; any other value raises
        ``NotImplementedError`` before weights are loaded.
    save_path : str
        Directory (relative to this file) holding the pickled fields and the
        ``model_state`` weight file.
    embedding_dim, nhead, max_seq_len, max_pondering_time :
        Hyper-parameters forwarded to ``UniversalTransformer``.
    dropout : float
        Unused; kept only for backward compatibility of the signature.

    Returns
    -------
    tuple
        ``(model, src_field, tgt_field, max_seq_len)``.
    """
    cwd = os.path.abspath(__file__).replace('/translate.py', '')
    src_path = f'{cwd}/{save_path}/src.pickle'
    tgt_path = f'{cwd}/{save_path}/tgt.pickle'

    # Reuse pickled Field objects when both exist; otherwise build them fresh.
    if os.path.exists(src_path) and os.path.exists(tgt_path):
        print('loading saved fields...')
        with open(src_path, 'rb') as s:
            src_field = pickle.load(s)
        with open(tgt_path, 'rb') as t:
            tgt_field = pickle.load(t)
    else:
        print('creating fields...')
        src_field, tgt_field = create_field(max_seq_len, save_path)

    model = UniversalTransformer(n_src_vocab=len(src_field.vocab),
                                 n_tgt_vocab=len(tgt_field.vocab),
                                 embedding_dim=embedding_dim,
                                 nhead=nhead,
                                 max_seq_len=max_seq_len,
                                 max_pondering_time=max_pondering_time)

    print('loading weights...')
    if device != 'cpu':
        # Any non-CPU device raised here in the original as well, which made
        # its later `if device == 'cuda': model = model.cuda()` branch
        # unreachable; that dead branch has been removed.
        raise NotImplementedError('prediction on GPU is not implemented.')
    model.load_state_dict(
        torch.load(f'{cwd}/{save_path}/model_state',
                   map_location=torch.device('cpu')))

    return model, src_field, tgt_field, max_seq_len
model_name = args.models_dir + '/' + args.prefix + hp_str # build the model model = UniversalTransformer(SRCs[0], TRG, args) # logger.info(str(model)) if args.load_from is not None: with torch.cuda.device(args.gpu): model.load_state_dict( torch.load(args.models_dir + '/' + args.load_from + '.pt', map_location=lambda storage, loc: storage.cuda()) ) # load the pretrained models. # use cuda if args.gpu > -1: model.cuda(args.gpu) # additional information args.__dict__.update({ 'model_name': model_name, 'hp_str': hp_str, 'logger': logger, 'n_lang': len(args.aux) }) # tensorboard writer if args.tensorboard and (not args.debug): from tensorboardX import SummaryWriter writer = SummaryWriter('{}/{}'.format(args.runs_dir, args.prefix + args.hp_str)) else:
def main():
    """Train a UniversalTransformer English->French model.

    Parses hyper-parameters from the command line, builds (or loads) the
    torchtext fields, constructs the dataset iterator, initializes the model,
    runs ``_train`` and saves the resulting weights.
    """
    # --- command-line arguments -------------------------------------------
    parser = argparse.ArgumentParser(
        description='Initialize training parameter.')
    parser.add_argument('-device', required=True, type=str,
                        help='"cuda" or "cpu"')
    parser.add_argument('-save_path', type=str, default='saved')
    parser.add_argument('-use_saved_fields', action='store_true')
    parser.add_argument('-use_saved_weights', action='store_true')
    parser.add_argument('-epochs', type=int, default=10)
    parser.add_argument('-batch_size', type=int, default=3000)
    parser.add_argument('-max_seq_len', type=int, default=80)
    parser.add_argument('-max_pondering_time', type=int, default=10)
    # NOTE(review): -dropout and -feedforward_dim are parsed but never used
    # below — presumably consumed by _train via `args`; confirm.
    parser.add_argument('-dropout', type=float, default=0.5)
    parser.add_argument('-learning_rate', type=float, default=0.0001)
    parser.add_argument('-nhead', type=int, default=2)
    parser.add_argument('-embedding_dim', type=int, default=512)
    parser.add_argument('-feedforward_dim', type=int, default=2048)
    parser.add_argument('-lr_scheduling', action='store_true')
    args = parser.parse_args()

    src_lang = 'en'
    tgt_lang = 'fr'
    cwd = os.path.abspath(__file__).replace('/train.py', '')

    # --- fields ------------------------------------------------------------
    if args.use_saved_fields:
        if args.device != 'cpu':
            # Same message/effect as the original `exit(...)` call.
            raise SystemExit('use_saved_fields option can be used on only cpu.')
        print('loading saved fields...')
        with open(f'{cwd}/{args.save_path}/src.pickle', 'rb') as s:
            src_field = pickle.load(s)
        with open(f'{cwd}/{args.save_path}/tgt.pickle', 'rb') as t:
            tgt_field = pickle.load(t)
        print('end.')
    else:
        print('creating fields...')
        src_field: torchtext.data.field.Field = torchtext.data.Field(
            lower=True, tokenize=Tokenize(src_lang))
        tgt_field: torchtext.data.field.Field = torchtext.data.Field(
            lower=True, tokenize=Tokenize(tgt_lang),
            init_token='<sos>', eos_token='<eos>')
        print('end.')

    # --- dataset & iterator -------------------------------------------------
    print('creating dataset iterator...')
    # Fix: the original leaked both file handles (open(...).read()).
    with open(f"{cwd}/data/english.txt") as f:
        src_data = f.read().strip().split('\n')
    with open(f"{cwd}/data/french.txt") as f:
        tgt_data = f.read().strip().split('\n')
    df = pd.DataFrame({'src': src_data, 'tgt': tgt_data},
                      columns=["src", "tgt"])
    # Drop sentence pairs whose whitespace-token count reaches max_seq_len.
    too_long_mask = ((df['src'].str.count(' ') < args.max_seq_len)
                     & (df['tgt'].str.count(' ') < args.max_seq_len))
    df = df.loc[too_long_mask]  # remove too long sentence
    # TabularDataset only reads from disk, so round-trip through a temp CSV;
    # fix: remove it even when dataset construction raises.
    df.to_csv("tmp_dataset.csv", index=False)
    try:
        dataset = torchtext.data.TabularDataset(
            './tmp_dataset.csv', format='csv',
            fields=[('src', src_field), ('tgt', tgt_field)])
    finally:
        os.remove('tmp_dataset.csv')
    dataset_iter = MyIterator(dataset, batch_size=args.batch_size,
                              device=args.device, repeat=False,
                              sort_key=lambda x: (len(x.src), len(x.tgt)),
                              batch_size_fn=batch_size_fn, train=True,
                              shuffle=True)

    # Build vocab, save field objects.
    src_field.build_vocab(dataset)
    tgt_field.build_vocab(dataset)
    print('end.')
    if not args.use_saved_fields:
        print('saving fields...')
        # Fix: the original handed unclosed file objects to pickle.dump.
        with open(f'{cwd}/{args.save_path}/src.pickle', 'wb') as s:
            pickle.dump(src_field, s)
        with open(f'{cwd}/{args.save_path}/tgt.pickle', 'wb') as t:
            pickle.dump(tgt_field, t)
        print('end.')

    # Index of the last batch (= batch count - 1); the original built a
    # throwaway list just to take its final element.
    iteration_num = sum(1 for _ in dataset_iter) - 1

    # --- model --------------------------------------------------------------
    model = UniversalTransformer(n_src_vocab=len(src_field.vocab),
                                 n_tgt_vocab=len(tgt_field.vocab),
                                 embedding_dim=args.embedding_dim,
                                 nhead=args.nhead,
                                 max_seq_len=args.max_seq_len,
                                 max_pondering_time=args.max_pondering_time)
    if args.use_saved_weights:
        print('loading saved model states...')
        model.load_state_dict(
            torch.load(f'{cwd}/{args.save_path}/model_state'))
        print('end.')
    else:
        # Xavier init for every weight matrix; biases (dim <= 1) are left to
        # the module defaults.
        for param in model.parameters():
            if param.dim() > 1:
                nn.init.xavier_normal_(param)
    if args.device == 'cuda':
        model = model.cuda()

    # --- training -----------------------------------------------------------
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 betas=(0.9, 0.98), eps=1e-9)
    lr_scheduler = CosineAnnealingLR(optimizer, iteration_num)
    _train(model, dataset_iter, optimizer, lr_scheduler, args,
           src_field, tgt_field, iteration_num)

    print('saving weights...')
    torch.save(model.state_dict(), f'{cwd}/{args.save_path}/model_state')
    print('end.')