Example #1
def do_train(config):
    paddle.set_device("gpu" if config.n_gpu else "cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()
    set_seed(config)

    base_graph, term_ids = load_data(config.graph_work_path)
    collate_fn = partial(
        batch_fn,
        samples=config.samples,
        base_graph=base_graph,
        term_ids=term_ids)

    mode = 'train'
    train_ds = TrainData(config.graph_work_path)
    model = ErnieSageForLinkPrediction.from_pretrained(
        config.model_name_or_path, config=config)
    # Wrap the model for (potentially multi-GPU) data-parallel training.
    model = paddle.DataParallel(model)

    train_loader = GraphDataLoader(
        train_ds,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.sample_workers,
        collate_fn=collate_fn)

    optimizer = paddle.optimizer.Adam(
        learning_rate=config.lr, parameters=model.parameters())

    global_step = 0
    tic_train = time.time()
    for epoch in range(config.epoch):
        for step, (graphs, datas) in enumerate(train_loader):
            global_step += 1
            loss, outputs = model(graphs, datas)
            if global_step % config.log_per_step == 0:
                logger.info(
                    "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss,
                       config.log_per_step / (time.time() - tic_train)))
                tic_train = time.time()
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            if global_step % config.save_per_step == 0:
                # Save intermediate checkpoints only from the main process.
                if (not config.n_gpu > 1) or paddle.distributed.get_rank() == 0:
                    output_dir = os.path.join(config.output_path,
                                              "model_%d" % global_step)
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model._layers.save_pretrained(output_dir)
    # Save the final model only from the main process.
    if (not config.n_gpu > 1) or paddle.distributed.get_rank() == 0:
        output_dir = os.path.join(config.output_path, "last")
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        model._layers.save_pretrained(output_dir)
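
Everything this function needs comes from a single config object. Below is a minimal sketch of one, built only from the attribute names the snippet itself reads; the values are illustrative, and the seed field is an assumption about what set_seed(config) expects.

from types import SimpleNamespace

# Illustrative values; the attribute names are the ones do_train() accesses above.
config = SimpleNamespace(
    n_gpu=1,                           # 0 falls back to CPU
    seed=2021,                         # assumed to be read by set_seed(config)
    graph_work_path="./graph_workdir",
    samples=[25, 10],                  # neighbour sample sizes per layer (illustrative)
    model_name_or_path="ernie-1.0",
    batch_size=32,
    sample_workers=2,
    lr=5e-5,
    epoch=3,
    log_per_step=10,
    save_per_step=1000,
    output_path="./output",
)
# do_train(config)  # requires paddle, pgl and the rest of the training script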
Example #2
def train(**kwargs):

    # attributes
    for k, v in kwargs.items():
        setattr(opt, k, v)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    torch.backends.cudnn.enabled = False

    # dataset
    if opt.chinese:
        mydata = TrainData(opt.chinese_data_path, opt.conversation_path,
                           opt.chinese_results_path, opt.chinese, opt.fb,
                           opt.prev_sent, True)
    else:
        mydata = TrainData(opt.data_path, opt.conversation_path,
                           opt.results_path, opt.chinese, opt.fb,
                           opt.prev_sent, True)

    # models
    if opt.attn:
        seq2seq = NewSeq2seqAttention(num_tokens=mydata.data.num_tokens,
                                      opt=opt,
                                      sos_id=mydata.data.word2id["<START>"])

        if opt.model_attention_path:
            seq2seq.load_state_dict(
                torch.load(opt.model_attention_path, map_location="cpu"))
            print("Pretrained model has been loaded.\n")
    else:
        seq2seq = NewSeq2seq(num_tokens=mydata.data.num_tokens,
                             opt=opt,
                             sos_id=mydata.data.word2id["<START>"])

        if opt.chinese:
            if opt.chinese_model_path:
                seq2seq.load_state_dict(
                    torch.load(opt.chinese_model_path, map_location="cpu"))
                print("Pretrained model has been loaded.\n")
        else:
            if opt.model_path:
                seq2seq.load_state_dict(
                    torch.load(opt.model_path, map_location="cpu"))
                print("Pretrained model has been loaded.\n")

    seq2seq = seq2seq.to(device)

    #=============================================================#

    optimizer = RMSprop(seq2seq.parameters(), lr=opt.learning_rate)
    criterion = nn.CrossEntropyLoss().to(device)

    for epoch in range(opt.epochs):
        print("epoch %d:" % epoch)
        mini_batches = mydata._mini_batches(opt.batch_size)

        for ii, (ib, tb) in enumerate(mini_batches):

            ib = ib.to(device)
            tb = tb.to(device)

            optimizer.zero_grad()
            decoder_outputs, decoder_hidden1, decoder_hidden2 = seq2seq(ib, tb)

            # Debug output: the source tokens and the model's top-1 predictions
            # for the first sequence in the batch.
            src_words = []
            pred_words = []
            for t in range(opt.mxlen):
                _, indices = torch.topk(decoder_outputs[t][0], 1)
                src_words.append(mydata.data.id2word[ib[t][0].item()])
                pred_words.append(mydata.data.id2word[indices[0].item()])
            print(src_words)
            print(pred_words)

            # Reshape the outputs and shift the targets by one step so the
            # decoder is trained to predict the next token.
            b = decoder_outputs.size(1)  # batch_size
            t = decoder_outputs.size(0)  # time_steps
            targets = torch.zeros(t, b).to(device)  # (time_steps, batch_size)
            targets[:-1, :] = tb[1:, :]

            targets = targets.contiguous().view(-1)  # (time_steps*batch_size)
            decoder_outputs = decoder_outputs.view(
                b * t, -1)  # (time_steps*batch_size) x vocab_size
            loss = criterion(decoder_outputs, targets.long())

            # `ii % 1 == 0` is always true, so the loss is printed every batch.
            if ii % 1 == 0:
                print("Current Loss:", loss.item())

            loss.backward()
            optimizer.step()
        if opt.chinese:
            save_path = "checkpoints/chinese-epoch-%s.pth" % epoch
        else:
            save_path = "checkpoints/epoch-%s.pth" % epoch

        torch.save(seq2seq.state_dict(), save_path)
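
Because train(**kwargs) copies every keyword argument onto the module-level opt object before training starts, hyperparameters can be overridden per call. A hypothetical invocation (all attribute names appear in the snippet; the values are illustrative):

# Every keyword below becomes an attribute on the global `opt` before training.
train(
    chinese=False,
    attn=False,
    batch_size=64,
    learning_rate=1e-3,
    epochs=20,
    model_path="checkpoints/epoch-19.pth",  # optional: warm-start from a saved checkpoint
)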
Example #3
def test(**kwargs):

    # attributes
    for k, v in kwargs.items():
        setattr(opt, k, v)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    torch.backends.cudnn.enabled = False

    prev_sentence = ""

    while True:

        data = input('You say: ')
        if data == "exit":
            break
        if opt.prev_sent == 2:
            data = (prev_sentence + data) if not opt.chinese else data

        # Dataset
        if opt.chinese:
            data = list(convert(data, 't2s'))
            mydata = TrainData(opt.chinese_data_path, opt.conversation_path,
                               opt.chinese_results_path, opt.chinese, opt.fb,
                               opt.prev_sent, True)
        else:
            data = ' '.join(data.split(' '))
            mydata = TrainData(opt.data_path, opt.conversation_path,
                               opt.results_path, opt.chinese, opt.fb,
                               opt.prev_sent, True)

        # models
        if opt.attn:
            seq2seq = NewSeq2seqAttention(
                num_tokens=mydata.data.num_tokens,
                opt=opt,
                sos_id=mydata.data.word2id["<START>"])
            if opt.model_attention_path:
                seq2seq.load_state_dict(
                    torch.load(opt.model_attention_path, map_location="cpu"))

        elif opt.rl:
            seq2seq = NewSeq2seqRL(num_tokens=mydata.data.num_tokens,
                                   opt=opt,
                                   sos_id=mydata.data.word2id["<START>"])
            if opt.model_rl_path:
                seq2seq.load_state_dict(
                    torch.load(opt.model_rl_path, map_location="cpu"))
            seq2seq = seq2seq.to(opt.device)

        else:
            seq2seq = NewSeq2seq(num_tokens=mydata.data.num_tokens,
                                 opt=opt,
                                 sos_id=mydata.data.word2id["<START>"])

            if opt.chinese:
                if opt.chinese_model_path:
                    seq2seq.load_state_dict(
                        torch.load(opt.chinese_model_path, map_location="cpu"))
            else:
                if opt.model_path:
                    seq2seq.load_state_dict(
                        torch.load(opt.model_path, map_location="cpu"))

        seq2seq = seq2seq.to(device)

        # Predict
        encoder_data = mydata._test_batch(
            data, 2 * opt.mxlen if not opt.chinese else opt.mxlen).to(device)
        # encoder_data = torch.LongTensor([[  1,  30, 112,  10,   3,   2,   1, 645, 131,   7,  25,   7, 146, 584,
        # 871, 207,  16,   3,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,
        #   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]]).transpose(1,0)
        decoded_indices, decoder_hidden1, decoder_hidden2 = seq2seq.evaluation(
            encoder_data)

        toks_to_replace = {
            "i": "I",
            "im": "I'm",
            "id": "I'd",
            "ill": "I'll",
            "iv": "I'v",
            "hes": "he's",
            "shes": "she's",
            "youre": "you're",
            "its": "it's",
            "dont": "don't",
            "youd": "you'd",
            "cant": "can't",
            "thats": "that's",
            "isnt": "isn't",
            "didnt": "didn't",
            "hows": "how's",
            "ive": "I've"
        }

        decoded_sequence = ""
        for idx in decoded_indices:
            idx = idx.item() if opt.rl else idx
            sampled_tok = mydata.data.id2word[idx]
            if sampled_tok == "<START>":
                continue
            elif sampled_tok == "<EOS>":
                break
            else:
                if not opt.chinese:
                    if sampled_tok in toks_to_replace:
                        sampled_tok = toks_to_replace[sampled_tok]
                    decoded_sequence += sampled_tok + ' '
                else:
                    decoded_sequence += sampled_tok

        print("WayneBot:",decoded_sequence if not opt.chinese \
            else convert(decoded_sequence,'s2t').replace("雞仔","我").replace("主人","陛下").replace("主子","陛下"))
        prev_sentence = decoded_sequence
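
The index-to-text loop above is self-contained enough to factor into a helper. The sketch below mirrors its logic one-for-one; the function name is hypothetical, and id2word / toks_to_replace are passed in explicitly so the function is standalone.

def indices_to_text(decoded_indices, id2word, toks_to_replace, chinese=False):
    """Hypothetical helper mirroring the decoding loop above."""
    decoded_sequence = ""
    for idx in decoded_indices:
        idx = idx.item() if hasattr(idx, "item") else idx  # tensors or plain ints
        tok = id2word[idx]
        if tok == "<START>":
            continue
        if tok == "<EOS>":
            break
        if chinese:
            decoded_sequence += tok            # no spaces between Chinese tokens
        else:
            decoded_sequence += toks_to_replace.get(tok, tok) + " "
    return decoded_sequence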
Example #4
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--n_epochs', type=int, default=100)
    parser.add_argument('--log_freq', type=int, default=30)
    parser.add_argument('--plot_freq', type=int, default=250)
    parser.add_argument('--save_freq', type=int, default=10)
    # eval setting
    parser.add_argument('--val_fraction', type=float, default=0.1)
    parser.add_argument('--eval_batch_size', type=int, default=32)
    parser.add_argument('--eval_plot_freq', type=int, default=10)
    args = parser.parse_args()

    model = DeepGMR(args)
    if torch.cuda.is_available():
        model.cuda()

    data = TrainData(args.data_file, args)
    # Random train / validation split.
    ids = np.random.permutation(len(data))
    n_val = int(args.val_fraction * len(data))
    train_data = Subset(data, ids[n_val:])
    valid_data = Subset(data, ids[:n_val])

    train_loader = DataLoader(train_data,
                              args.batch_size,
                              drop_last=True,
                              shuffle=True)
    valid_loader = DataLoader(valid_data, args.eval_batch_size, drop_last=True)

    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           factor=0.5,
                                                           min_lr=1e-6)
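
The snippet ends right after the scheduler is built. ReduceLROnPlateau only lowers the learning rate when the metric passed to step() stops improving, so the training loop (not shown) has to feed it a validation metric. A hypothetical continuation, where train_one_epoch() and evaluate() stand in for the rest of the script:

# Hypothetical continuation; train_one_epoch() and evaluate() are placeholders
# for code that is not part of this snippet.
for epoch in range(args.n_epochs):
    train_one_epoch(model, train_loader, optimizer)
    val_loss = evaluate(model, valid_loader)
    scheduler.step(val_loss)  # ReduceLROnPlateau reacts to the monitored metric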
Example #5
def train_RL():

    torch.backends.cudnn.enabled = False

    dull_set = [
        "I don't know what you're talking about.", "I don't know.",
        "You don't know.", "You know what I mean.", "I know what you mean.",
        "You know what I'm saying.", "You don't know anything."
    ]
    ones_reward = torch.ones(opt.mxlen).to(opt.device)

    # dataset
    mydata = TrainData(opt.data_path,
                       opt.conversation_path,
                       opt.results_path,
                       chinese=False,
                       fb=False,
                       prev_sent=opt.prev_sent)

    # Ease of answering data
    dull_target_set = make_batch(mydata.data, dull_set, opt.mxlen)
    dull_target_set = dull_target_set.permute(1, 0)
    dull_target_data = torch.empty(opt.batch_size, opt.mxlen,
                                   dtype=torch.long).to(opt.device)
    for i in range(opt.batch_size):
        dull_target_data[i] = dull_target_set[np.random.randint(len(dull_set))]
    dull_target_data = dull_target_data.permute(1, 0)

    # models
    # 1. RL model
    seq2seq_rl = NewSeq2seqRL(num_tokens=mydata.data.num_tokens,
                              opt=opt,
                              sos_id=mydata.data.word2id["<START>"])
    if opt.model_rl_path:
        seq2seq_rl.load_state_dict(
            torch.load(opt.model_rl_path, map_location="cpu"))
        print("Pretrained RL model has been loaded.")
    seq2seq_rl = seq2seq_rl.to(opt.device)

    # 2. Normal model
    seq2seq_normal = NewSeq2seq(num_tokens=mydata.data.num_tokens,
                                opt=opt,
                                sos_id=mydata.data.word2id["<START>"])
    if opt.model_path:
        seq2seq_normal.load_state_dict(
            torch.load(opt.model_path, map_location="cpu"))
        print("Pretrained Normal model has been loaded.")
    seq2seq_normal = seq2seq_normal.to(opt.device)

    #=============================================================#

    optimizer = RMSprop(seq2seq_rl.parameters(), lr=opt.learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(opt.epochs):
        print("epoch %d:" % epoch)
        mini_batches = mydata._mini_batches(opt.batch_size)

        for ii, (ib, tb) in enumerate(mini_batches):

            # Skip the last batch to avoid error
            if ib.size(1) != opt.batch_size:
                continue

            ib = ib.to(opt.device)
            tb = tb.to(opt.device)

            optimizer.zero_grad()

            # First evaluate an output
            action_words, _, _ = seq2seq_rl.evaluation(ib)

            # Ease of answering data
            dull_outputs, dull_loss, dull_entropies = seq2seq_rl(
                action_words, dull_target_data, ones_reward)

            # Semantic Coherence: Forward
            forward_outputs, forward_loss, forward_entropies = seq2seq_rl(
                ib, action_words, ones_reward)

            # Semantic Coherence: Backward
            backward_outputs, _, _ = seq2seq_normal(action_words,
                                                    ib[:opt.mxlen])
            backward_targets = ib[:opt.mxlen]  # (time_steps, batch_size)
            backward_entropies = []
            for t in range(opt.mxlen):
                backward_loss = criterion(backward_outputs[t],
                                          backward_targets[t].long())
                backward_entropies.append(backward_loss)

            rewards = count_rewards(dull_entropies, forward_entropies,
                                    backward_entropies)

            # Add rewards to train the data
            decoder_outputs, pg_loss, entropies = seq2seq_rl(ib, tb, rewards)

            # `ii % 1 == 0` is always true, so the loss is printed every batch.
            if ii % 1 == 0:
                print("Current Loss:", pg_loss.item())

            pg_loss.backward()
            optimizer.step()

            save_path = "checkpoints/rl-epoch-%s.pth" % epoch
            if (epoch + 1) % 10 == 0:
                torch.save(seq2seq_rl.state_dict(), save_path)
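
count_rewards() is defined elsewhere in the source module and is not shown here. Purely as an illustration of how rewards of this kind are combined in the Li et al. (2016) dialogue-RL formulation (ease of answering plus forward/backward semantic coherence), a weighted-sum sketch might look like the following; the real implementation may differ in shape, sign conventions, and weighting.

import torch

def count_rewards_sketch(dull_entropies, forward_entropies, backward_entropies,
                         weights=(0.25, 0.25, 0.5)):
    """Illustrative stand-in; not the module's actual count_rewards().

    Assumes each argument is a per-timestep sequence of cross-entropy tensors of
    equal length (opt.mxlen) and that the caller expects a per-timestep reward
    vector, like ones_reward above.
    """
    # Ease of answering: high cross-entropy against the dull targets means the
    # reply is unlikely to provoke a dull response, which is rewarded.
    r1 = torch.stack([e.detach() for e in dull_entropies])
    # Semantic coherence (forward and backward): low cross-entropy is good,
    # so these terms enter with a negative sign.
    r2 = -torch.stack([e.detach() for e in forward_entropies])
    r3 = -torch.stack([e.detach() for e in backward_entropies])
    w1, w2, w3 = weights
    return w1 * r1 + w2 * r2 + w3 * r3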
Example #6
 def __init__(self, confPath=""):
     Network.__init__(self, confPath)
     self.data = TrainData(
         self.conf, self.conf["epochWalked"] / self.conf["updateEpoch"])
     self.data.check_data()
     self.GPU0 = '0'