Example #1
def train_model(
    train_source,
    train_target,
    dev_source,
    dev_target,
    experiment_directory,
    resume=False,
):
    # Prepare dataset
    train = Seq2SeqDataset.from_file(train_source, train_target)
    train.build_vocab(300, 6000)
    dev = Seq2SeqDataset.from_file(
        dev_source,
        dev_target,
        share_fields_from=train,
    )
    input_vocab = train.src_field.vocab
    output_vocab = train.tgt_field.vocab

    # Prepare loss
    weight = torch.ones(len(output_vocab))
    pad = output_vocab.stoi[train.tgt_field.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    seq2seq = None
    optimizer = None
    if not resume:
        seq2seq, optimizer, scheduler = initialize_model(
            train, input_vocab, output_vocab)

    # Train
    trainer = SupervisedTrainer(
        loss=loss,
        batch_size=32,
        checkpoint_every=50,
        print_every=10,
        experiment_directory=experiment_directory,
    )
    start = time.perf_counter()  # wall-clock timer (time.clock() was removed in Python 3.8)
    try:
        seq2seq = trainer.train(
            seq2seq,
            train,
            n_epochs=10,
            dev_data=dev,
            optimizer=optimizer,
            teacher_forcing_ratio=0.5,
            resume=resume,
        )
    # Capture ^C
    except KeyboardInterrupt:
        pass
    end = time.perf_counter() - start
    logging.info('Training time: %.2fs', end)

    return seq2seq, input_vocab, output_vocab
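
The return values of train_model plug directly into an inference helper. The sketch below is a hypothetical call site: the file paths and experiment directory are placeholders, and it assumes the standard pytorch-seq2seq Predictor API (from seq2seq.evaluator import Predictor); a fork of the library may expose it under a different path.

# Hypothetical inference sketch; paths are placeholders.
from seq2seq.evaluator import Predictor

model, input_vocab, output_vocab = train_model(
    'data/train.src', 'data/train.tgt',
    'data/dev.src', 'data/dev.tgt',
    experiment_directory='experiments/demo',
)
predictor = Predictor(model, input_vocab, output_vocab)
# Predictor.predict takes a list of source tokens and returns the decoded target tokens.
print(predictor.predict('how are you'.split()))
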
Example #2
def build_model(src, tgt, hidden_size, mini_batch_size, bidirectional, dropout,
                attention, init_value):
    EXPERIMENT.param("Hidden", hidden_size)
    EXPERIMENT.param("Bidirectional", bidirectional)
    EXPERIMENT.param("Dropout", dropout)
    EXPERIMENT.param("Attention", attention)
    EXPERIMENT.param("Mini-batch", mini_batch_size)
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    encoder = EncoderRNN(len(src.vocab),
                         MAX_LEN,
                         hidden_size,
                         rnn_cell="lstm",
                         bidirectional=bidirectional,
                         dropout_p=dropout,
                         variable_lengths=False)
    decoder = DecoderRNN(
        len(tgt.vocab),
        MAX_LEN,
        hidden_size * 2 if bidirectional else hidden_size,  # decoder width must match the (bi)directional encoder output
        rnn_cell="lstm",
        use_attention=attention,
        eos_id=tgt.eos_id,
        sos_id=tgt.sos_id)
    seq2seq = Seq2seq(encoder, decoder)
    using_cuda = False
    if torch.cuda.is_available():
        using_cuda = True
        encoder.cuda()
        decoder.cuda()
        seq2seq.cuda()
        loss.cuda()
    EXPERIMENT.param("CUDA", using_cuda)
    for param in seq2seq.parameters():
        param.data.uniform_(-init_value, init_value)

    trainer = SupervisedTrainer(loss=loss,
                                batch_size=mini_batch_size,
                                checkpoint_every=5000,
                                random_seed=42,
                                print_every=1000)
    return seq2seq, trainer
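
build_model returns the model together with a configured trainer but does not start training; a typical caller would look roughly like the sketch below. The dataset objects (src, tgt, train_data, dev_data) and the hyperparameter values are placeholders assumed to exist in the surrounding project.

# Hypothetical call site for build_model; dataset objects are assumed to exist.
seq2seq, trainer = build_model(src, tgt,
                               hidden_size=256,
                               mini_batch_size=64,
                               bidirectional=True,
                               dropout=0.2,
                               attention=True,
                               init_value=0.08)
seq2seq = trainer.train(seq2seq, train_data, num_epochs=10, dev_data=dev_data)
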
Example #3
def load_model_data_evaluator(expt_dir, model_name, data_path, batch_size=128):
    checkpoint_path = os.path.join(expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, model_name)
    checkpoint = Checkpoint.load(checkpoint_path)
    model = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    data_all, data_sml, data_med, data_lrg, fields_inp, src, tgt, src_adv, idx_field = load_data(data_path)

    src.vocab = input_vocab
    tgt.vocab = output_vocab
    src_adv.vocab = input_vocab

    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()
    evaluator = Evaluator(loss=loss, batch_size=batch_size)

    return model, data_all, data_sml, data_med, data_lrg, evaluator, fields_inp
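
The returned evaluator pairs with the loaded model: Evaluator.evaluate computes the average loss and token accuracy over a dataset, as in the minimal sketch below (the experiment directory, checkpoint name, and data path are placeholders).

# Hypothetical usage of load_model_data_evaluator.
results = load_model_data_evaluator('experiments/demo', 'best_checkpoint', 'data/test')
model, data_all, data_sml, data_med, data_lrg, evaluator, fields_inp = results
test_loss, accuracy = evaluator.evaluate(model, data_all)
print('loss = %.4f, accuracy = %.4f' % (test_loss, accuracy))
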
Example #4
def pretrain_generator(model, train, dev):
    # pre-train generator
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    # NOTE: tgt, random_seed, expt_gen_dir, and resume are defined outside this
    # function in the original project; the generator being optimized is `model`.
    optimizer = Optimizer(torch.optim.Adam(model.parameters()), max_grad_norm=5)
    scheduler = StepLR(optimizer.optimizer, 1)
    optimizer.set_scheduler(scheduler)

    supervised = SupervisedTrainer(loss=loss,
                                   batch_size=32,
                                   random_seed=random_seed,
                                   expt_dir=expt_gen_dir)
    supervised.train(model,
                     train,
                     num_epochs=20,
                     dev_data=dev,
                     optimizer=optimizer,
                     teacher_forcing_ratio=0,
                     resume=resume)
Example #5
    # inputs = torchtext.Field(lower=True, include_lengths=True, batch_first=True)
    # inputs.build_vocab(src.vocab)
    src.vocab.load_vectors(wv_type='glove.840B', wv_dim=300)

    # NOTE: If the source field name and the target field name
    # are different from 'src' and 'tgt' respectively, they have
    # to be set explicitly before any training or inference
    # seq2seq.src_field_name = 'src'
    # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    seq2seq = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        # hidden_size=128
        hidden_size = 300
        bidirectional = True

        encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                             bidirectional=bidirectional, variable_lengths=True)
        decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else hidden_size,
                             dropout_p=0.2, use_attention=True, bidirectional=bidirectional,
                             eos_id=tgt.eos_id, sos_id=tgt.sos_id)
        seq2seq = Seq2seq(encoder, decoder)
Example #6
def run_training(opt, default_data_dir, num_epochs=100):
    if opt.load_checkpoint is not None:
        logging.info("loading checkpoint from {}".format(
            os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)))
        checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        input_vocab = checkpoint.input_vocab
        output_vocab = checkpoint.output_vocab
    else:

        # Prepare dataset
        src = SourceField()
        tgt = TargetField()
        max_len = 50

        data_file = os.path.join(default_data_dir, opt.train_path, 'data.txt')

        logging.info("Starting new Training session on %s", data_file)

        def len_filter(example):
            return (len(example.src) <= max_len) and (len(example.tgt) <= max_len) \
                   and (len(example.src) > 0) and (len(example.tgt) > 0)

        train = torchtext.data.TabularDataset(
            path=data_file, format='json',
            fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
            filter_pred=len_filter
        )

        dev = None
        if opt.no_dev is False:
            dev_data_file = os.path.join(default_data_dir, opt.train_path, 'dev-data.txt')
            dev = torchtext.data.TabularDataset(
                path=dev_data_file, format='json',
                fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
                filter_pred=len_filter
            )

        src.build_vocab(train, max_size=50000)
        tgt.build_vocab(train, max_size=50000)
        input_vocab = src.vocab
        output_vocab = tgt.vocab

        # NOTE: If the source field name and the target field name
        # are different from 'src' and 'tgt' respectively, they have
        # to be set explicitly before any training or inference
        # seq2seq.src_field_name = 'src'
        # seq2seq.tgt_field_name = 'tgt'

        # Prepare loss
        weight = torch.ones(len(tgt.vocab))
        pad = tgt.vocab.stoi[tgt.pad_token]
        loss = Perplexity(weight, pad)
        if torch.cuda.is_available():
            logging.info("Yayyy We got CUDA!!!")
            loss.cuda()
        else:
            logging.info("No cuda available device found running on cpu")

        seq2seq = None
        optimizer = None
        if not opt.resume:
            hidden_size = 128
            decoder_hidden_size = hidden_size * 2
            logging.info("EncoderRNN Hidden Size: %s", hidden_size)
            logging.info("DecoderRNN Hidden Size: %s", decoder_hidden_size)
            bidirectional = True
            encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                                 bidirectional=bidirectional,
                                 rnn_cell='lstm',
                                 variable_lengths=True)
            decoder = DecoderRNN(len(tgt.vocab), max_len, decoder_hidden_size,
                                 dropout_p=0, use_attention=True,
                                 bidirectional=bidirectional,
                                 rnn_cell='lstm',
                                 eos_id=tgt.eos_id, sos_id=tgt.sos_id)

            seq2seq = Seq2seq(encoder, decoder)
            if torch.cuda.is_available():
                seq2seq.cuda()

            for param in seq2seq.parameters():
                param.data.uniform_(-0.08, 0.08)

        # The optimizer and learning rate scheduler can be customized by
        # constructing the objects explicitly and passing them to the trainer.

        optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        scheduler = StepLR(optimizer.optimizer, 1)
        optimizer.set_scheduler(scheduler)

        # train

        batch_size = 32
        checkpoint_every = max(1, num_epochs // 10)
        print_every = max(1, num_epochs // 100)

        properties = dict(batch_size=batch_size,
                          checkpoint_every=checkpoint_every,
                          print_every=print_every, expt_dir=opt.expt_dir,
                          num_epochs=num_epochs,
                          teacher_forcing_ratio=0.5,
                          resume=opt.resume)

        logging.info("Starting training with the following Properties %s", json.dumps(properties, indent=2))
        t = SupervisedTrainer(loss=loss, batch_size=num_epochs,
                              checkpoint_every=checkpoint_every,
                              print_every=print_every, expt_dir=opt.expt_dir)

        seq2seq = t.train(seq2seq, train,
                          num_epochs=num_epochs, dev_data=dev,
                          optimizer=optimizer,
                          teacher_forcing_ratio=0.5,
                          resume=opt.resume)

        evaluator = Evaluator(loss=loss, batch_size=batch_size)

        if opt.no_dev is False:
            dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
            logging.info("Dev Loss: %s", dev_loss)
            logging.info("Accuracy: %s", dev_loss)

    beam_search = Seq2seq(seq2seq.encoder, TopKDecoder(seq2seq.decoder, 4))

    predictor = Predictor(beam_search, input_vocab, output_vocab)
    while True:
        try:
            seq_str = input("Type in a source sequence: ")
            seq = seq_str.strip().split()
            results = predictor.predict_n(seq, n=3)
            for i, res in enumerate(results):
                print('option %s: %s\n' % (i + 1, res))
        except KeyboardInterrupt:
            logging.info("Bye Bye")
            exit(0)
Example #7
File: textsum.py (project: ffiamz/tmp)
def train():
    src = SourceField(sequential=True, tokenize=jieba.lcut)
    tgt = TargetField(sequential=True, tokenize=jieba.lcut)
    max_len = 50

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='csv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='csv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)

    src.build_vocab(train, max_size=50000)
    tgt.build_vocab(train, max_size=50000)
    input_vocab = src.vocab
    output_vocab = tgt.vocab

    # NOTE: If the source field name and the target field name
    # are different from 'src' and 'tgt' respectively, they have
    # to be set explicitly before any training or inference
    # seq2seq.src_field_name = 'src'
    # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    seq2seq = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        hidden_size = 128
        bidirectional = True
        encoder = EncoderRNN(len(src.vocab),
                             max_len,
                             hidden_size,
                             bidirectional=bidirectional,
                             variable_lengths=True)
        decoder = DecoderRNN(len(tgt.vocab),
                             max_len,
                             hidden_size * 2 if bidirectional else hidden_size,
                             dropout_p=0.2,
                             use_attention=True,
                             bidirectional=bidirectional,
                             eos_id=tgt.eos_id,
                             sos_id=tgt.sos_id)
        seq2seq = Seq2seq(encoder, decoder)
        if torch.cuda.is_available():
            seq2seq.cuda()

        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

        # The optimizer and learning rate scheduler can be customized by
        # constructing the objects explicitly and passing them to the trainer.
        #
        # optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        # scheduler = StepLR(optimizer.optimizer, 1)
        # optimizer.set_scheduler(scheduler)

    # train
    t = SupervisedTrainer(loss=loss,
                          batch_size=32,
                          checkpoint_every=50,
                          print_every=10,
                          expt_dir=opt.expt_dir)

    seq2seq = t.train(seq2seq,
                      train,
                      num_epochs=6,
                      dev_data=dev,
                      optimizer=optimizer,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)
    predictor = Predictor(seq2seq, input_vocab, output_vocab)
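
The predictor built on the last line is never queried in this snippet. A caller would typically feed it a jieba-tokenized sentence, e.g. by appending the lines below at the end of train(); this is only a sketch and the input string is an arbitrary example.

    # Hypothetical follow-up: summarize one sentence with the trained model.
    tokens = jieba.lcut('今天天气很好')
    summary = [t for t in predictor.predict(tokens)
               if t not in ('<sos>', '<eos>', '<pad>')]
    print(''.join(summary))
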
Example #8
    # input_vocab = src.vocab
    # output_vocab = tgt.vocab

src_adv.vocab = src.vocab

logging.info('Indices of special replace tokens:\n')
for rep in replace_tokens:
    logging.info("%s, %d; " % (rep, src.vocab.stoi[rep]))
logging.info('\n')

# Prepare loss
weight = torch.ones(len(tgt.vocab))
pad = tgt.vocab.stoi[tgt.pad_token]
loss = Perplexity(weight, pad)
if torch.cuda.is_available():
    loss.cuda()

batch_adv_loss = Perplexity(weight, pad)
if torch.cuda.is_available():
    batch_adv_loss.cuda()

# seq2seq = None
optimizer = None
if not opt.resume:
    # Initialize model
    hidden_size = params['hidden_size']
    bidirectional = True
    encoder = EncoderRNN(len(src.vocab),
                         max_len,
                         hidden_size,
                         bidirectional=bidirectional,
Example #9
def apply_gradient_attack(data, model, input_vocab, replace_tokens, field_name,
                          opt):
    def convert_to_onehot(inp, vocab_size):
        return torch.zeros(inp.size(0), inp.size(1), vocab_size,
                           device=device).scatter_(2, inp.unsqueeze(2), 1.)

    batch_iterator = torchtext.data.BucketIterator(
        dataset=data,
        batch_size=opt.batch_size,
        sort=True,
        sort_within_batch=True,
        sort_key=lambda x: len(x.src),
        device=device,
        repeat=False)
    batch_generator = batch_iterator.__iter__()

    weight = torch.ones(len(tgt.vocab)).half()
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()
    model.train()

    d = {}

    for batch in tqdm.tqdm(batch_generator, total=len(batch_iterator)):
        indices = getattr(batch, 'index')
        input_variables, input_lengths = getattr(batch, field_name)
        target_variables = getattr(batch, 'tgt')

        # Do random attack if inputs are too long and will OOM under gradient attack
        if max(getattr(batch, field_name)[1]) > 250:
            rand_replacements = get_random_token_replacement(
                input_variables.cpu().numpy(), input_vocab,
                indices.cpu().numpy(), replace_tokens, opt.distinct)

            d.update(rand_replacements)
            continue

        # convert input_variables to one_hot
        input_onehot = Variable(convert_to_onehot(input_variables,
                                                  vocab_size=len(input_vocab)),
                                requires_grad=True).half()

        # Forward propagation
        decoder_outputs, decoder_hidden, other = model(input_onehot,
                                                       input_lengths,
                                                       target_variables,
                                                       already_one_hot=True)

        # print outputs for debugging
        # for i,output_seq_len in enumerate(other['length']):
        #	print(i,output_seq_len)
        #	tgt_id_seq = [other['sequence'][di][i].data[0] for di in range(output_seq_len)]
        #	tgt_seq = [output_vocab.itos[tok] for tok in tgt_id_seq]
        #	print(' '.join([x for x in tgt_seq if x not in ['<sos>','<eos>','<pad>']]), end=', ')
        #	gt = [output_vocab.itos[tok] for tok in target_variables[i]]
        #	print(' '.join([x for x in gt if x not in ['<sos>','<eos>','<pad>']]))

        # Get loss
        loss.reset()
        for step, step_output in enumerate(decoder_outputs):
            batch_size = target_variables.size(0)
            loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                            target_variables[:, step + 1])
        # Backward propagation
        model.zero_grad()
        input_onehot.retain_grad()
        loss.backward(retain_graph=True)
        grads = input_onehot.grad
        del input_onehot

        best_replacements = get_best_token_replacement(
            input_variables.cpu().numpy(),
            grads.cpu().numpy(), input_vocab,
            indices.cpu().numpy(), replace_tokens, opt.distinct)

        d.update(best_replacements)

    return d
Example #10
            opt.dev_path, opt.mean_std_path, opt.max_len, opt.min_len, ixtoword, wordtoix)

    weight = torch.ones(len(train_label_lang.word2index))
    for word in train_label_lang.word2index:
        if train_label_lang.word2count[word] == 0:
            continue
        index = train_label_lang.word2index[word]
        weight[index] = weight[index] * opt.count_smooth / float(
            math.pow(train_label_lang.word2count[word], 0.8))

    # Prepare loss
    pad = train_label_lang.word2index["<pad>"]
    lloss = Perplexity(weight, pad, opt.lamda1)
    bloss = BBLoss(opt.batch_size, opt.gmm_comp_num, opt.lamda2)
    if torch.cuda.is_available():
        lloss.cuda()
        bloss.cuda()

    print('train_label_lang.index2word:')
    for index in train_label_lang.index2word:
        print('{} : {} '.format(index, train_label_lang.index2word[index]))

    print('train_label_lang.word2count:')
    for word in train_label_lang.word2count:
        print('{} : {} '.format(word, train_label_lang.word2count[word]))

    hidden_size = opt.embedding_dim
    encoder = PreEncoderRNN(train_cap_lang.n_words, nhidden=opt.embedding_dim)
    state_dict = torch.load(opt.encoder_path,
                            map_location=lambda storage, loc: storage)
    encoder.load_state_dict(state_dict)
Example #11
class auto_seq2seq:
    def __init__(self,
                 data_path,
                 model_save_path,
                 model_load_path,
                 hidden_size=32,
                 max_vocab=4000,
                 device='cuda'):
        self.src = SourceField()
        self.tgt = TargetField()
        self.max_length = 90
        self.data_path = data_path
        self.model_save_path = model_save_path
        self.model_load_path = model_load_path

        def len_filter(example):
            return len(example.src) <= self.max_length and len(
                example.tgt) <= self.max_length

        self.trainset = torchtext.data.TabularDataset(
            path=os.path.join(self.data_path, 'train'),
            format='tsv',
            fields=[('src', self.src), ('tgt', self.tgt)],
            filter_pred=len_filter)
        self.devset = torchtext.data.TabularDataset(
            path=os.path.join(self.data_path, 'eval'),
            format='tsv',
            fields=[('src', self.src), ('tgt', self.tgt)],
            filter_pred=len_filter)
        self.src.build_vocab(self.trainset, max_size=max_vocab)
        self.tgt.build_vocab(self.trainset, max_size=max_vocab)
        weight = torch.ones(len(self.tgt.vocab))
        pad = self.tgt.vocab.stoi[self.tgt.pad_token]
        self.loss = Perplexity(weight, pad)
        self.loss.cuda()
        self.optimizer = None
        self.hidden_size = hidden_size
        self.bidirectional = True
        encoder = EncoderRNN(len(self.src.vocab),
                             self.max_length,
                             self.hidden_size,
                             bidirectional=self.bidirectional,
                             variable_lengths=True)
        decoder = DecoderRNN(len(self.tgt.vocab),
                             self.max_length,
                             self.hidden_size *
                             2 if self.bidirectional else self.hidden_size,
                             dropout_p=0.2,
                             use_attention=True,
                             bidirectional=self.bidirectional,
                             eos_id=self.tgt.eos_id,
                             sos_id=self.tgt.sos_id)
        self.device = device
        self.seq2seq = Seq2seq(encoder, decoder).cuda()
        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

    def train(self, epoch=20, resume=False):
        t = SupervisedTrainer(loss=self.loss,
                              batch_size=96,
                              checkpoint_every=1000,
                              print_every=1000,
                              expt_dir=self.model_save_path)
        self.seq2seq = t.train(self.seq2seq,
                               self.trainset,
                               num_epochs=epoch,
                               dev_data=self.devset,
                               optimizer=self.optimizer,
                               teacher_forcing_ratio=0.5,
                               resume=resume)
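
A driver for this class would construct it and call train(), roughly as in the sketch below; the data directory (expected to contain 'train' and 'eval' TSV files) and the checkpoint directory are placeholders.

# Hypothetical driver for auto_seq2seq.
model = auto_seq2seq(data_path='data/',
                     model_save_path='experiments/auto',
                     model_load_path=None,
                     hidden_size=64,
                     max_vocab=8000)
model.train(epoch=10)
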
Example #12
File: train.py (project: HaoTse/seq2seq-sum)
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-data', required=True)

    parser.add_argument('-epoch', type=int, default=3)
    parser.add_argument('-batch_size', type=int, default=64)

    parser.add_argument('-d_model', type=int, default=1024)
    parser.add_argument('-n_layer', type=int, default=1)

    parser.add_argument('-dropout', type=float, default=0)

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='best')

    parser.add_argument('-seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-teacher_forcing_ratio', type=float, default=0.5)

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model
    opt.log = opt.save_model

    random.seed(opt.seed)
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.seed)

    #========= Loading Dataset =========#
    data = torch.load(opt.data)
    opt.max_token_seq_len = data['settings'].max_token_seq_len

    training_data, validation_data = prepare_dataloaders(data, opt)

    opt.src_vocab_size = training_data.dataset.src_vocab_size
    opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size

    #========= Preparing Model =========#
    print(opt)
    device = torch.device('cuda' if opt.cuda else 'cpu')

    # model
    opt.bidirectional = True
    encoder = EncoderRNN(opt.src_vocab_size,
                         opt.max_token_seq_len,
                         opt.d_model,
                         bidirectional=opt.bidirectional,
                         variable_lengths=True)
    decoder = DecoderRNN(opt.tgt_vocab_size,
                         opt.max_token_seq_len,
                         opt.d_model * 2 if opt.bidirectional else opt.d_model,
                         n_layers=opt.n_layer,
                         dropout_p=opt.dropout,
                         use_attention=True,
                         bidirectional=opt.bidirectional,
                         eos_id=Constants.EOS,
                         sos_id=Constants.BOS)
    seq2seq = Seq2seq(encoder, decoder).to(device)
    for param in seq2seq.parameters():
        param.data.uniform_(-0.08, 0.08)

    seq2seq = nn.DataParallel(seq2seq)

    # loss
    weight = torch.ones(opt.tgt_vocab_size)
    pad = Constants.PAD
    loss = Perplexity(weight, pad)
    if opt.cuda:
        loss.cuda()

    # optimizer
    optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()),
                          max_grad_norm=5)

    train(seq2seq, training_data, validation_data, loss, optimizer, device,
          opt)