Example #1
0
def load_model_and_field(device='cpu',
                         save_path='saved',
                         embedding_dim=512,
                         nhead=2,
                         max_seq_len=80,
                         max_pondering_time=10,
                         dropout=0.5):
    """Load the source/target fields and a trained UniversalTransformer.

    Fields are unpickled from ``<this file's dir>/<save_path>/src.pickle`` and
    ``tgt.pickle`` when both exist; otherwise they are rebuilt with
    ``create_field(max_seq_len, save_path)``. Model weights are read from
    ``<this file's dir>/<save_path>/model_state`` and mapped onto the CPU.

    Args:
        device: Target device. Only ``'cpu'`` is supported.
        save_path: Directory, relative to this file's directory, that holds
            the pickled fields and the saved model state.
        embedding_dim: Forwarded to ``UniversalTransformer``.
        nhead: Forwarded to ``UniversalTransformer``.
        max_seq_len: Forwarded to ``UniversalTransformer`` and
            ``create_field``; also returned to the caller.
        max_pondering_time: Forwarded to ``UniversalTransformer``.
        dropout: Accepted for interface compatibility; it is not forwarded
            to the model.

    The model hyper-parameters presumably need to match the training run so
    that the saved state dict loads — TODO confirm against UniversalTransformer.

    Returns:
        Tuple ``(model, src_field, tgt_field, max_seq_len)``.

    Raises:
        NotImplementedError: If ``device`` is anything other than ``'cpu'``.
    """
    # Reject unsupported devices before any (possibly expensive) field
    # creation or model construction happens.
    if device != 'cpu':
        raise NotImplementedError('prediction on GPU is not implemented.')

    # dirname() resolves the script directory regardless of the file's name
    # and of the platform's path separator.
    save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            save_path)
    src_path = os.path.join(save_dir, 'src.pickle')
    tgt_path = os.path.join(save_dir, 'tgt.pickle')

    # load field
    if os.path.exists(src_path) and os.path.exists(tgt_path):
        print('loading saved fields...')
        # NOTE: unpickling can execute arbitrary code; only point save_path at
        # field files produced by this project's own training run.
        with open(src_path, 'rb') as s:
            src_field = pickle.load(s)
        with open(tgt_path, 'rb') as t:
            tgt_field = pickle.load(t)
    else:
        print('creating fields...')
        src_field, tgt_field = create_field(max_seq_len, save_path)

    # load model; vocabulary sizes come from the (loaded or created) fields.
    model = UniversalTransformer(n_src_vocab=len(src_field.vocab),
                                 n_tgt_vocab=len(tgt_field.vocab),
                                 embedding_dim=embedding_dim,
                                 nhead=nhead,
                                 max_seq_len=max_seq_len,
                                 max_pondering_time=max_pondering_time)
    print('loading weights...')
    # map_location lets weights saved from a CUDA run load on a CPU-only host.
    model.load_state_dict(
        torch.load(os.path.join(save_dir, 'model_state'),
                   map_location=torch.device('cpu')))
    return model, src_field, tgt_field, max_seq_len
Example #2
0
model_name = args.models_dir + '/' + args.prefix + hp_str

# build the model
model = UniversalTransformer(SRCs[0], TRG, args)

# logger.info(str(model))
if args.load_from is not None:
    pretrained_path = args.models_dir + '/' + args.load_from + '.pt'
    if args.gpu > -1:
        # Map every checkpoint storage onto the selected GPU.
        with torch.cuda.device(args.gpu):
            pretrained_state = torch.load(
                pretrained_path,
                map_location=lambda storage, loc: storage.cuda())
    else:
        # No GPU requested (args.gpu <= -1): keep tensors on the CPU, since
        # storage.cuda() would fail on a CPU-only machine.
        pretrained_state = torch.load(pretrained_path, map_location='cpu')
    model.load_state_dict(pretrained_state)  # load the pretrained models.

# use cuda
if args.gpu > -1:
    model.cuda(args.gpu)

# additional information, attached to args so downstream code can read it
vars(args).update({
    'model_name': model_name,
    'hp_str': hp_str,
    'logger': logger,
    'n_lang': len(args.aux)
})
# tensorboard writer
if args.tensorboard and (not args.debug):
    from tensorboardX import SummaryWriter
    writer = SummaryWriter('{}/{}'.format(args.runs_dir,
                                          args.prefix + args.hp_str))
else:
Example #3
0
def _parse_train_args():
    """Build the training command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description='Initialize training parameter.')
    parser.add_argument('-device',
                        required=True,
                        type=str,
                        help='"cuda" or "cpu"')
    parser.add_argument('-save_path', type=str, default='saved')
    parser.add_argument('-use_saved_fields', action='store_true')
    parser.add_argument('-use_saved_weights', action='store_true')
    parser.add_argument('-epochs', type=int, default=10)
    parser.add_argument('-batch_size', type=int, default=3000)
    parser.add_argument('-max_seq_len', type=int, default=80)
    parser.add_argument('-max_pondering_time', type=int, default=10)
    parser.add_argument('-dropout', type=float, default=0.5)
    parser.add_argument('-learning_rate', type=float, default=0.0001)
    parser.add_argument('-nhead', type=int, default=2)
    parser.add_argument('-embedding_dim', type=int, default=512)
    parser.add_argument('-feedforward_dim', type=int, default=2048)
    parser.add_argument('-lr_scheduling', action='store_true')
    return parser.parse_args()


def _load_or_create_fields(args, save_dir, src_lang, tgt_lang):
    """Return ``(src_field, tgt_field)``, unpickled from save_dir or new.

    Saved fields are used only with ``-use_saved_fields``, and that option is
    only accepted together with ``-device cpu`` (the process exits otherwise).
    """
    import sys

    if not args.use_saved_fields:
        print('creating fields...')
        src_field = torchtext.data.Field(lower=True,
                                         tokenize=Tokenize(src_lang))
        tgt_field = torchtext.data.Field(lower=True,
                                         tokenize=Tokenize(tgt_lang),
                                         init_token='<sos>',
                                         eos_token='<eos>')
        print('end.')
        return src_field, tgt_field

    if args.device != 'cpu':
        sys.exit('use_saved_fields option can be used on only cpu.')
    print('loading saved fields...')
    with open(os.path.join(save_dir, 'src.pickle'), 'rb') as s:
        src_field = pickle.load(s)
    with open(os.path.join(save_dir, 'tgt.pickle'), 'rb') as t:
        tgt_field = pickle.load(t)
    print('end.')
    return src_field, tgt_field


def _build_train_dataset(args, cwd, src_field, tgt_field):
    """Read the parallel corpus, drop over-long pairs, return a TabularDataset.

    A pair is kept only when both sentences contain fewer than
    ``args.max_seq_len`` spaces. The filtered frame is round-tripped through
    a private temporary CSV because ``TabularDataset`` reads from a path.
    """
    import tempfile

    with open(os.path.join(cwd, 'data', 'english.txt'),
              encoding='utf-8') as f:
        src_data = f.read().strip().split('\n')
    with open(os.path.join(cwd, 'data', 'french.txt'),
              encoding='utf-8') as f:
        tgt_data = f.read().strip().split('\n')
    df = pd.DataFrame({'src': src_data, 'tgt': tgt_data},
                      columns=['src', 'tgt'])
    short_enough_mask = ((df['src'].str.count(' ') < args.max_seq_len)
                         & (df['tgt'].str.count(' ') < args.max_seq_len))
    df = df.loc[short_enough_mask]  # remove too long sentence

    fd, tmp_path = tempfile.mkstemp(suffix='.csv')
    os.close(fd)
    try:
        # header=False: TabularDataset would otherwise read the "src,tgt"
        # header line as a training example.
        df.to_csv(tmp_path, index=False, header=False)
        return torchtext.data.TabularDataset(tmp_path,
                                             format='csv',
                                             fields=[('src', src_field),
                                                     ('tgt', tgt_field)])
    finally:
        # Remove the temporary file even if dataset construction fails.
        os.remove(tmp_path)


def _save_fields(save_dir, src_field, tgt_field):
    """Pickle both fields into save_dir, creating the directory if needed."""
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'src.pickle'), 'wb') as s:
        pickle.dump(src_field, s)
    with open(os.path.join(save_dir, 'tgt.pickle'), 'wb') as t:
        pickle.dump(tgt_field, t)


def main():
    """Train an English->French UniversalTransformer from the command line.

    Steps:
        1. Parse the CLI arguments.
        2. Load or create the torchtext fields.
        3. Build the length-filtered dataset and its batch iterator.
        4. Build the vocabularies, pickling the fields when they were freshly
           created.
        5. Initialise (Xavier) or load the model.
        6. Train it via ``_train`` with Adam and a CosineAnnealingLR schedule.
        7. Save the final weights to ``<this file's dir>/<save_path>/model_state``.

    Exits with a message when ``-use_saved_fields`` is combined with a
    non-CPU device, or when no training batches are produced.
    """
    import sys

    args = _parse_train_args()
    src_lang = 'en'
    tgt_lang = 'fr'
    # Directory containing this script; dirname() is independent of the
    # script's file name and of the platform's path separator.
    cwd = os.path.dirname(os.path.abspath(__file__))
    save_dir = os.path.join(cwd, args.save_path)

    # create train iterator (create field, dataset, iterator)
    src_field, tgt_field = _load_or_create_fields(args, save_dir, src_lang,
                                                  tgt_lang)
    print('creating dataset iterator...')
    dataset = _build_train_dataset(args, cwd, src_field, tgt_field)
    dataset_iter = MyIterator(dataset,
                              batch_size=args.batch_size,
                              device=args.device,
                              repeat=False,
                              sort_key=lambda x: (len(x.src), len(x.tgt)),
                              batch_size_fn=batch_size_fn,
                              train=True,
                              shuffle=True)
    # build vocab, save field object and add variable.
    src_field.build_vocab(dataset)
    tgt_field.build_vocab(dataset)
    print('end.')
    if not args.use_saved_fields:
        print('saving fields...')
        _save_fields(save_dir, src_field, tgt_field)
        print('end.')

    # Index of the last batch (batch count - 1), i.e. the final enumerate()
    # index over one pass of the iterator. Counted without building a list.
    iteration_num = sum(1 for _ in dataset_iter) - 1
    if iteration_num < 0:
        sys.exit('no training batches: the dataset is empty after filtering.')

    # initialize model
    # NOTE(review): -dropout and -feedforward_dim are parsed but not passed
    # here; _train receives args and may use them — confirm intent.
    model = UniversalTransformer(n_src_vocab=len(src_field.vocab),
                                 n_tgt_vocab=len(tgt_field.vocab),
                                 embedding_dim=args.embedding_dim,
                                 nhead=args.nhead,
                                 max_seq_len=args.max_seq_len,
                                 max_pondering_time=args.max_pondering_time)
    # initialize param
    if args.use_saved_weights:
        print('loading saved model states...')
        # The model is still on the CPU here; mapping to CPU lets weights
        # saved from a CUDA run load on any host before the .cuda() below.
        model.load_state_dict(
            torch.load(os.path.join(save_dir, 'model_state'),
                       map_location='cpu'))
        print('end.')
    else:
        for param in model.parameters():
            if param.dim() > 1:
                nn.init.xavier_normal_(param)
    if args.device == 'cuda':
        model = model.cuda()

    # train model
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    # T_max must be positive; a single-batch dataset would otherwise give 0.
    lr_scheduler = CosineAnnealingLR(optimizer, max(iteration_num, 1))
    _train(model, dataset_iter, optimizer, lr_scheduler, args, src_field,
           tgt_field, iteration_num)
    print('saving weights...')
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, 'model_state'))
    print('end.')