Example #1
def train_model(model, opt):

    print("training model...")
    model.train()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()

    for epoch in range(opt.epochs):

        total_loss = 0
        if opt.floyd is False:
            print("   %dm: epoch %d [%s]  %d%%  loss = %s" %\
            ((time.time() - start)//60, epoch + 1, "".join(' '*20), 0, '...'), end='\r')

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        for i, batch in enumerate(opt.train):
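            # the iterator yields (seq_len, batch) tensors; transpose to
            # (batch, seq_len). The decoder input is the target without its
            # last token and the loss targets are the target without its
            # first token (teacher forcing).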

            src = batch.src.transpose(0, 1)
            trg = batch.trg.transpose(0, 1)
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            ys = trg[:, 1:].contiguous().view(-1)
            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   ys,
                                   ignore_index=opt.trg_pad)
            loss.backward()
            opt.optimizer.step()
            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()

            if (i + 1) % opt.printevery == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss / opt.printevery
                if opt.floyd is False:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
                else:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
                total_loss = 0

            if opt.checkpoint > 0 and (
                (time.time() - cptime) // 60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()


        print("%d s: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" %\
        ((time.time() - start), epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))

        print("saving weights to " + opt.output_dir + "/...")
        torch.save(model.state_dict(), f'{opt.output_dir}/model_weights')
        print("weights saved ! ")
Example #2
def evaluate(model, iterator, criterion, opt):

    model.eval()
    epoch_loss = 0

    print(f"Evaluating data...")

    with torch.no_grad():

        for i, batch in tqdm(enumerate(iterator)):
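            # the transformer branch scores batch-first, teacher-forced targets;
            # the non-transformer branch keeps time-major tensors and instead
            # drops the first (<sos>) position from outputs and targets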

            if opt.nmt_model_type == 'transformer':
                src = batch.src.transpose(0, 1)
                trg = batch.trg.transpose(0, 1)
                src_mask, trg_mask = create_masks(src, trg[:, :-1], opt)
                output = model(src, trg[:, :-1], src_mask, trg_mask)
                output = output.contiguous().view(-1, output.shape[-1])
                trg = trg[:, 1:].contiguous().view(-1)
            else:
                src = batch.src
                trg = batch.trg
                output = model(src, trg)
                output = output[1:].view(-1, output.shape[-1])
                trg = trg[1:].view(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)
Example #3
def test_model(model, opt):
    print("Testing model...")
    model.eval()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()

    loss_log = tqdm(total=0, bar_format='{desc}', position=2)
    # for epoch in range(opt.epochs):
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (enc_input, dec_input, dec_output) in enumerate(
                tqdm(opt.test, desc="Iteration", position=0)):
            enc_input = enc_input.to(opt.device)
            dec_input = dec_input.to(opt.device)
            dec_output = dec_output.to(opt.device)

            src_mask, trg_mask = create_masks(enc_input, dec_input, opt)

            preds = model(enc_input, dec_input, src_mask, trg_mask)

            ys = dec_output.contiguous().view(-1)

            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   ys,
                                   ignore_index=opt.trg_pad)
            test_loss += loss.item()

            if (batch_idx + 1) % opt.printevery == 0:
                p = int(100 * (batch_idx + 1) / opt.test_len)
                loss_log.set_description_str("   test progress: %d%%" % p)

        test_avg_loss = test_loss / (batch_idx + 1)

        print("%dm: Evaluated loss = %.3f\n" %\
        ((time.time() - start)//60, test_avg_loss))
Example #4
def eval_epoch(model, valid_data, opt):

    model.eval()

    total_loss, n_word_total, n_word_correct = 0, 0, 0

    with torch.no_grad():
        for i, batch in enumerate(valid_data):

            src = batch[0].to(opt.device)
            trg = batch[1].to(opt.device)

            trg_input = trg[:, :-1]
            trg_output = trg[:, 1:].contiguous().view(-1)

            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            preds = preds.view(-1, preds.size(-1))

            n_correct, n_word = cal_performance(preds, trg_output, opt.trg_pad)

            loss = F.cross_entropy(preds, trg_output, ignore_index=opt.trg_pad)

            n_word_total += n_word
            n_word_correct += n_correct
            total_loss += loss.item()

    return total_loss, n_word_total, n_word_correct
Example #5
def train_model(epochs, print_every=100):

    model.train()

    start = time.time()
    temp = start
    cptime = start  # timer for the periodic checkpoint save further down

    total_loss = 0

    for epoch in range(epochs):

        for i, batch in enumerate(train_iter):
            src = batch.English.transpose(0, 1)
            trg = batch.French.transpose(0, 1)
            # the French sentence we input has all words except
            # the last, as it is using each word to predict the next

            trg_input = trg[:, :-1]

            # the words we are trying to predict

            targets = trg[:, 1:].contiguous().view(-1)

            # create function to make masks using mask code above

            src_mask, trg_mask = create_masks(src, trg_input)

            preds = model(src, trg_input, src_mask, trg_mask)

            optim.zero_grad()

            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   targets,
                                   ignore_index=target_pad)
            loss.backward()
            optim.step()

            total_loss += loss.item()

            if (i + 1) % print_every == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss / print_every
                if opt.floyd is False:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
                else:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
                total_loss = 0

            if opt.checkpoint > 0 and (
                (time.time() - cptime) // 60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()


        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" %\
        ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))
Example #6
    def training_step(self, batch, batch_idx):
        src = batch.src.transpose(0, 1)
        trg = batch.trg.transpose(0, 1)
        trg_input = trg[:, :-1]
        src_mask, trg_mask = create_masks(src, trg_input, self.opt)
        preds = self.transformer(src, trg_input, src_mask, trg_mask)
        ys = trg[:, 1:].contiguous().view(-1)
        loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                               ys,
                               ignore_index=self.opt.trg_pad)
        # with Lightning's automatic optimization, returning the loss at the
        # end of training_step triggers backward(), optimizer.step() and
        # zero_grad(), so no manual optimizer calls are made here
        if self.opt.SGDR:
            self.opt.sched.step()

        self.total_loss += loss.item()

        if (batch_idx + 1) % self.opt.printevery == 0:
            p = int(100 * (batch_idx + 1) / self.opt.train_len)
            avg_loss = self.total_loss / self.opt.printevery
            # if opt.floyd is False:
            #     print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
            #           ((time.time() - start) // 60, epoch + 1, "".join('#' * (p // 5)),
            #            "".join(' ' * (20 - (p // 5))), p, avg_loss), end='\r')
            # else:
            #     print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
            #           ((time.time() - start) // 60, epoch + 1, "".join('#' * (p // 5)),
            #            "".join(' ' * (20 - (p // 5))), p, avg_loss))
            self.total_loss = 0

        # if opt.checkpoint > 0 and ((time.time() - self.cptime) // 60) // opt.checkpoint >= 1:
        #     torch.save(self.transformer.state_dict(), 'weights/model_weights.pth')
        #     self.cptime = time.time()

        if loss < self.best_loss:
            torch.save(self.transformer.state_dict(),
                       'weights/model_weights.pth')
            self.best_loss = loss

        self.log('training_loss',
                 loss.item(),
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)

        return loss
Example #7
def eval_model(model, evalloader, opt):
    model.eval()
    total_loss = 0
    for i, (src, trg) in enumerate(evalloader.batch_data_generator()):
        with torch.no_grad():
            trg_input = trg[:, :-1] # not include the end of sentence
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            ys = trg[:, 1:].contiguous().view(-1)
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys.long())
            total_loss += loss.item()
            
            # Print the first decoded sentence of each batch
            batch, seq, word = preds.size()
            sentences = []
            for j in range(batch):
                sentence = ' '.join(evalloader.get_sentence_from_tensor(preds[j]))
                sentences.append(sentence)
            print(sentences[0])

    print("Total Validation loss: {}".format(total_loss))
Example #8
def train_epoch(model, optimizer, train_data, opt, epoch, start_time):

    model.train()

    total_loss, n_word_total, n_word_correct = 0, 0, 0
    print("   %dm: epoch %d [%s]  %d%%  loss = %s" % \
          ((time.time() - start_time) // 60, epoch + 1, "".join(' ' * 20), 0, '...'), end='\r')

    for i, batch in enumerate(train_data):

        src = batch[0].to(opt.device)
        trg = batch[1].to(opt.device)

        trg_input = trg[:, :-1]
        trg_output = trg[:, 1:].contiguous().view(-1)

        src_mask, trg_mask = create_masks(src, trg_input, opt)
        preds = model(src, trg_input, src_mask, trg_mask)
        preds = preds.view(-1, preds.size(-1))

        n_correct, n_word = cal_performance(preds, trg_output, opt.trg_pad)

        optimizer.zero_grad()
        loss = F.cross_entropy(preds, trg_output, ignore_index=opt.trg_pad)
        loss.backward()
        optimizer.step()
        if opt.SGDR:
            opt.sched.step()

        n_word_total += n_word
        n_word_correct += n_correct
        total_loss += loss.item()

        if (i + 1) % opt.print_every == 0:
            p = int(100 * (i + 1) / len(train_data))
            avg_loss = total_loss / (i + 1)
            print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
                  ((time.time() - start_time) // 60, epoch + 1, "".join('#' * (p // 5)), "".join(' ' * (20 - (p // 5))), p,
                   avg_loss))

    return total_loss, n_word_total, n_word_correct
Example #9
def train_model(model, opt):

    print("training model...")
    model.train()
    start = time.time()

    for epoch in range(opt.epochs):

        total_loss = 0
        print("%dm: epoch %d [%s]  %d%%  loss = %s" %\
        ((time.time() - start)//60, epoch + 1, "".join(' '*20), 0, '...'), end='\r')

        for i, batch in enumerate(opt.train):

            src = batch.src.transpose(0, 1)
            trg = batch.trg.transpose(0, 1)
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            ys = trg[:, 1:].contiguous().view(-1)
            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   ys,
                                   ignore_index=opt.trg_pad)
            loss.backward()
            opt.optimizer.step()
            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()

            if (i + 1) % opt.printevery == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss / opt.printevery
                print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
                total_loss = 0


        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" %\
        ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))
Example #10
def train_model(model, opt):
    print("training model...")
    model.train()
    # start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()
    for epoch in range(opt.epochs):
        print("epoch", epoch)
        total_loss = 0
        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1)
            trg = batch.trg.transpose(0, 1)
            trg_input = trg[:, :-1]
            if opt.device == 0:
                src = src.cuda()
                trg_input = trg_input.cuda()
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            if opt.device == 0:
                ys = trg[:, 1:].contiguous().view(-1).cuda()
            else:
                ys = trg[:, 1:].contiguous().view(-1)
            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   ys,
                                   ignore_index=opt.trg_pad)
            loss.backward()
            opt.optimizer.step()
            if opt.SGDR:
                opt.sched.step()
            total_loss += loss.item()

        if (epoch + 1) % opt.checkpoint == 0:
            print(" epoch %d  loss = %.06f   " % (epoch + 1, total_loss / (i + 1)))
        if (epoch + 1) % opt.checkpoint == 0 or (
            (time.time() - cptime) // 60) // opt.checkpoint >= 1:
            model_path = config.weights + '/model_weights_' + str(epoch + 1)
            torch.save(model.state_dict(), model_path)
            print("epoch %d: model saved to %s" % (epoch + 1, model_path))
            cptime = time.time()
Example #11
def train_model(model, trainloader, evalloader, opt):
    for epoch in range(opt.epochs):
        total_loss = 0
        for i, (src, trg) in enumerate(trainloader.batch_data_generator()):
            trg_input = trg[:, :-1] # not include the end of sentence
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            ys = trg[:, 1:].contiguous().view(-1)
            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys.long())
            if i % opt.log_frequency == 0:
                print("Epoch [{}][{}] Batch [{}] Loss = {}".format(epoch, opt.epochs, i, loss.item()))
            loss.backward()
            opt.optimizer.step()
            if opt.SGDR:
                opt.sched.step()
            total_loss += loss.item()

        print("Epoch: [{}]/[{}], Loss: {}".format(epoch, opt.epochs, total_loss))
        if epoch % opt.save_freq == opt.save_freq - 1:
            torch.save(model.state_dict(), "{}/activitynetmodel_{}.pth".format(opt.model_save_dir, epoch))
            eval_model(model, evalloader, opt)
Example #12
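# Note: this fragment uses Chinese identifiers; rough gloss:
# 操作 = operation/action sequence, 图片 = image, 目标输出 = target output,
# 分 = split/chunk, 树枝 = branch (chunk size), 循环 = keep-looping flag,
# 计数 = step counter, 用时 = elapsed time, 抽样 = sampled tokens,
# 输出_实际 = actual output.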
                    操作_分_枝 = np.array(操作_分_表[i * 树枝:(i + 1) * 树枝])
                    图片_分_枝 = np.array(图片_分_表[i * 树枝:(i + 1) * 树枝])
                    目标输出_分_枝 = np.array(目标输出_分_表[i * 树枝:(i + 1) * 树枝])

                else:
                    操作_分_枝 = np.array(操作_分_表[i * 树枝:len(操作_分_表)])
                    图片_分_枝 = np.array(图片_分_表[i * 树枝:len(图片_分_表)],
                                      dtype=np.float32)
                    目标输出_分_枝 = np.array(目标输出_分_表[i * 树枝:len(目标输出_分_表)])
                    循环 = False

                操作_分_torch = torch.from_numpy(操作_分_枝).cuda(device)
                图片_分_torch = torch.from_numpy(图片_分_枝).cuda(device)
                目标输出_分_torch = torch.from_numpy(目标输出_分_枝).cuda(device)

                src_mask, trg_mask = create_masks(操作_分_torch, 操作_分_torch,
                                                  device)
                if 图片_分_torch.shape[0] != 操作_分_torch.shape[0]:
                    continue
                输出_实际_A = model(图片_分_torch, 操作_分_torch, trg_mask)
                lin = 输出_实际_A.view(-1, 输出_实际_A.size(-1))
                optimizer.zero_grad()
                loss = F.cross_entropy(lin,
                                       目标输出_分_torch.contiguous().view(-1),
                                       ignore_index=-1)
                if 计数 % 1 == 0:
                    print(loss)

                    time_end = time.time()
                    用时 = time_end - time_start

                    _, 抽样 = torch.topk(输出_实际_A, k=1, dim=-1)
Example #13
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('-load_weights', required=True)
    parser.add_argument('-k', type=int, default=3)
    parser.add_argument('-max_len', type=int, default=32)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-heads', type=int, default=8)
    parser.add_argument('-dropout', type=float, default=0.1)

    opt = parser.parse_args()

    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f'Load Tokenizer and Vocab...')
    sp_tokenizer = Tokenizer(is_train=False, model_prefix='spm')
    sp_vocab = sp_tokenizer.vocab

    print(f'Load the extended vocab...')
    vocab = Vocabulary.load_vocab('./data/vocab')

    ######################TEST DATA######################
    # fitting the test dataset dir
    test_data_dir = [
        './data/test/newstest2014_en', './data/test/newstest2014_de'
    ]
    test_dataset = Our_Handler(src_path=test_data_dir[0],
                               tgt_path=test_data_dir[1],
                               vocab=vocab,
                               tokenizer=sp_tokenizer,
                               max_len=32,
                               is_test=True)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=64,
                                 shuffle=False,
                                 drop_last=True)
    opt.test = test_dataloader
    opt.test_len = len(test_dataloader)
    ####################################################
    model = get_model(opt, len(vocab), len(vocab))

    model.eval()
    opt.src_pad = opt.trg_pad = vocab.pad_index

    test_loss = 0.
    test_ppl = 0.
    n_batches = 0
    for batch_idx, (enc_input, dec_input, dec_output) in enumerate(opt.test):

        if batch_idx == 1:
            break

        enc_input = enc_input.to(opt.device)
        dec_input = dec_input.to(opt.device)
        dec_output = dec_output.to(opt.device)

        src_mask, trg_mask = create_masks(enc_input, dec_input, opt)

        with torch.no_grad():
            preds = model(enc_input, dec_input, src_mask, trg_mask)

        ys = dec_output.contiguous().view(-1)

        loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                               ys,
                               ignore_index=opt.trg_pad)

        test_loss += loss.item()
        test_ppl += np.exp(loss.item())
        n_batches += 1

    avg_test_loss = test_loss / max(n_batches, 1)
    avg_ppl = test_ppl / max(n_batches, 1)
    print(f'Test loss: {avg_test_loss:.3f}, Test perplexity: {avg_ppl:.3f}')
Example #14
def train_model(model, opt):

    print("training model...")
    # model.train()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()

    early_stopping = EarlyStopping(patience=5, verbose=1)

    loss_log = tqdm(total=0, bar_format='{desc}', position=2)

    training_loss_list = []
    val_loss_list = []

    for epoch in range(opt.epochs):
        # for epoch in trange(opt.epochs, desc="Epoch", position=0):

        if opt.floyd is False:
            print("   %dm: epoch %d [%s]  %d%%  loss = %s" %\
            ((time.time() - start)//60, epoch + 1, "".join(' '*20), 0, '...'), end='\r')

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        # Training the model
        model.train()
        total_loss = 0

        for batch_idx, (enc_input, dec_input,
                        dec_output) in enumerate(opt.train):
            # for batch_idx, (enc_input, dec_input, dec_output) in enumerate(tqdm(opt.train, desc="Iteration", ncols=100, position=1)):
            enc_input = enc_input.to(opt.device)
            dec_input = dec_input.to(opt.device)
            dec_output = dec_output.to(opt.device)

            #trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(enc_input, dec_input, opt)

            preds = model(enc_input, dec_input, src_mask, trg_mask)

            #ys = trg[:, 1:].contiguous().view(-1)
            ys = dec_output.contiguous().view(-1)

            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   ys,
                                   ignore_index=opt.trg_pad)
            loss.backward()
            opt.optimizer.step()
            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()
            avg_loss = total_loss / opt.printevery

            if (batch_idx + 1) % opt.printevery == 0:
                p = int(100 * (batch_idx + 1) / opt.train_len)
                print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %
                      ((time.time() - start) // 60, epoch + 1, "".join(
                          '#' * (p // 5)), "".join(' ' *
                                                   (20 -
                                                    (p // 5))), p, avg_loss))
                training_loss_list.append(avg_loss)
                # if opt.floyd is False:
                #     print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                #     ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
                # else:
                #     print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                #     ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
                total_loss = 0

            if opt.checkpoint > 0 and (
                (time.time() - cptime) // 60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()

        ## Validating the model
        model.eval()
        val_loss = 0
        val_step = 0
        early_stopped = False

        with torch.no_grad():
            for batch_idx, (enc_input, dec_input,
                            dec_output) in enumerate(opt.validation):
                # for batch_idx, (enc_input, dec_input, dec_output) in enumerate(tqdm(opt.validation, desc="Iteration", ncols=100, position=1)):
                enc_input = enc_input.to(opt.device)
                dec_input = dec_input.to(opt.device)
                dec_output = dec_output.to(opt.device)

                src_mask, trg_mask = create_masks(enc_input, dec_input, opt)

                preds = model(enc_input, dec_input, src_mask, trg_mask)

                ys = dec_output.contiguous().view(-1)

                loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                       ys,
                                       ignore_index=opt.trg_pad)
                val_loss += loss.item()
                val_step += 1

        val_loss = val_loss / val_step
        val_loss_list.append(val_loss)
        print("epoch %d, loss = %.3f" % (epoch + 1, val_loss))

        if early_stopping.validate(val_loss):
            break

    write_csv_file(training_loss_list, 'train_loss')
    write_csv_file(val_loss_list, 'val_loss')
Example #15
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-data_path', required=True)
    parser.add_argument('-output_dir', required=True)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-SGDR', action='store_true')
    parser.add_argument('-val_check_every_n', type=int, default=3)
    parser.add_argument('-calculate_val_loss', action='store_true')
    parser.add_argument('-val_forward_pass', action='store_true')
    parser.add_argument('-tensorboard_graph', action='store_true')
    parser.add_argument('-alex', action='store_true')
    parser.add_argument('-compositional_eval', action='store_true')
    parser.add_argument('-char_tokenization', action='store_true')
    parser.add_argument('-n_val', type=int, default=1000)
    parser.add_argument('-n_test', type=int, default=1000)
    parser.add_argument('-do_test', action='store_true')
    parser.add_argument('-epochs', type=int, default=50)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-heads', type=int, default=8)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-batchsize', type=int, default=3000)
    parser.add_argument('-printevery', type=int, default=100)
    parser.add_argument('-lr', type=float, default=0.0001)
    parser.add_argument('-load_weights')
    parser.add_argument('-create_valset', action='store_true')
    parser.add_argument('-max_strlen', type=int, default=512)
    parser.add_argument('-floyd', action='store_true')
    parser.add_argument('-checkpoint', type=int, default=0)

    opt = parser.parse_args()

    opt.device = 0 if opt.no_cuda is False else -1
    if opt.device == 0:
        assert torch.cuda.is_available()
        if opt.alex:
            torch.cuda.set_device(1)

    read_data(opt)
    SRC, TRG = create_fields(opt)
    opt.train, opt.val = create_dataset(opt, SRC, TRG)
    model = get_model(opt, len(SRC.vocab), len(TRG.vocab), SRC)

    if opt.tensorboard_graph:
        writer = SummaryWriter('runs')
        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()
            trg = batch.trg.transpose(0, 1).cuda()
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            writer.add_graph(model, (src, trg_input, src_mask, trg_mask))
            break
        writer.close()

    # beam search parameters
    opt.k = 1
    opt.max_len = opt.max_strlen

    opt.optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)
    opt.scheduler = ReduceLROnPlateau(opt.optimizer, factor=0.5, patience=5, verbose=True)

    if opt.SGDR:
        opt.sched = CosineWithRestarts(opt.optimizer, T_max=opt.train_len)

    if opt.checkpoint > 0:
        print(
            "model weights will be saved every %d minutes and at end of epoch to directory weights/" % (opt.checkpoint))

    train_model(model, opt, SRC, TRG)

    if opt.floyd is False:
        promptNextAction(model, opt, SRC, TRG)
Example #16
def train_model(model, opt, SRC, TRG):
    print("training model...")
    model.train()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()
                 
    for epoch in range(opt.epochs):
        model.train()
        total_loss = 0
        errors_per_epoch = 0
        if opt.floyd is False:
            print("   %dm: epoch %d [%s]  %d%%  loss = %s" %\
            ((time.time() - start)//60, epoch + 1, "".join(' '*20), 0, '...'), end='\r')
        
        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()
            trg = batch.trg.transpose(0, 1).cuda()

            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            ys = trg[:, 1:].contiguous().view(-1)

            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
            loss.backward()
            opt.optimizer.step()

            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()

            if (i + 1) % opt.printevery == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss / opt.printevery
                if opt.floyd is False:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
                else:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
                total_loss = 0

            if opt.checkpoint > 0 and ((time.time()-cptime)//60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()

        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" %\
        ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))

        print('errors per epoch:', errors_per_epoch)
        if opt.calculate_val_loss:
            model.eval()
            val_losses = []
            with torch.no_grad():
                for i, batch in enumerate(opt.val):
                    src = batch.src.transpose(0, 1).cuda()
                    trg = batch.trg.transpose(0, 1).cuda()
                    trg_input = trg[:, :-1]
                    src_mask, trg_mask = create_masks(src, trg_input, opt)
                    preds = model(src, trg_input, src_mask, trg_mask)
                    ys = trg[:, 1:].contiguous().view(-1)
                    loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
                    val_losses.append(loss.item())

            print('validation loss:', sum(val_losses)/len(val_losses), '\n')

        if (epoch + 1) % opt.val_check_every_n == 0:
            model.eval()
            val_acc, val_success = 0, 0
            val_data = zip_io_data(opt.data_path + '/val')
            for j, e in enumerate(val_data[:opt.n_val]):
                e_src, e_tgt = e[0], e[1]

                if opt.compositional_eval:
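                    # compositional evaluation: the input is split into controller
                    # steps; a later step may consume earlier beam-search outputs,
                    # joined with ' @@SEP@@ ', before being tokenised and decoded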
                    controller = eval_split_input(e_src)
                    intermediates = []
                    comp_failure = False
                    for controller_input in controller:
                        if len(controller_input) == 1:
                            controller_src = controller_input[0]

                        else:
                            controller_src = ''
                            for src_index in range(len(controller_input) - 1):
                                controller_src += intermediates[controller_input[src_index]] + ' @@SEP@@ '
                            controller_src += controller_input[-1]
                            controller_src = remove_whitespace(controller_src)

                        indexed = []
                        sentence = SRC.preprocess(controller_src)
                        for tok in sentence:
                            if SRC.vocab.stoi[tok] != 0:
                                indexed.append(SRC.vocab.stoi[tok])
                            else:
                                comp_failure = True
                                break
                        if comp_failure:
                            break

                        sentence = Variable(torch.LongTensor([indexed]))
                        if opt.device == 0:
                            sentence = sentence.cuda()

                        try:
                            sentence = beam_search(sentence, model, SRC, TRG, opt)
                            intermediates.append(sentence)
                        except Exception as e:
                            comp_failure = True

                            break

                    if not comp_failure:
                        try:
                            val_acc += simple_em(intermediates[-1], e_tgt)
                            val_success += 1
                        except Exception as e:
                            continue
                else:
                    sentence = SRC.preprocess(e_src)
                    indexed = [SRC.vocab.stoi[tok] for tok in sentence]

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = beam_search(sentence, model, SRC, TRG, opt)
                    except Exception as e:
                        continue
                    try:
                        val_acc += simple_em(sentence, e_tgt)
                        val_success += 1
                    except Exception as e:
                        continue

            if val_success == 0:
                val_success = 1
            val_acc = val_acc / val_success
            print('epoch', epoch, '- val accuracy:', round(val_acc * 100, 2))
            print()
            opt.scheduler.step(val_acc)

        if epoch == opt.epochs - 1 and opt.do_test:
            model.eval()
            test_data = zip_io_data(opt.data_path + '/test')
            test_predictions = ''
            test_acc, test_success = 0, 0
            for j, e in enumerate(test_data[:opt.n_test]):
                if (j + 1) % 10000 == 0:
                    print(round(j/len(test_data) * 100, 2), '% complete with testing')
                e_src, e_tgt = e[0], e[1]

                if opt.compositional_eval:
                    controller = eval_split_input(e_src)
                    intermediates = []
                    comp_failure = False
                    for controller_input in controller:
                        if len(controller_input) == 1:
                            controller_src = controller_input[0]

                        else:
                            controller_src = ''
                            for src_index in range(len(controller_input) - 1):
                                controller_src += intermediates[controller_input[src_index]] + ' @@SEP@@ '
                            controller_src += controller_input[-1]
                            controller_src = remove_whitespace(controller_src)

                        indexed = []
                        sentence = SRC.preprocess(controller_src)
                        for tok in sentence:
                            if SRC.vocab.stoi[tok] != 0:
                                indexed.append(SRC.vocab.stoi[tok])
                            else:
                                comp_failure = True
                                break
                        if comp_failure:
                            break

                        sentence = Variable(torch.LongTensor([indexed]))
                        if opt.device == 0:
                            sentence = sentence.cuda()

                        try:
                            sentence = beam_search(sentence, model, SRC, TRG, opt)
                            intermediates.append(sentence)
                        except Exception as e:
                            comp_failure = True
                            break

                    if not comp_failure:
                        try:
                            test_acc += simple_em(sentence, e_tgt)
                            test_success += 1
                            test_predictions += sentence + '\n'
                        except Exception as e:
                            test_predictions += '\n'
                            continue
                    else:
                        test_predictions += '\n'
                else:
                    indexed = []
                    sentence = SRC.preprocess(e_src)
                    pass_bool = False
                    for tok in sentence:
                        if SRC.vocab.stoi[tok] != 0:
                            indexed.append(SRC.vocab.stoi[tok])
                        else:
                            pass_bool = True
                            break
                    if pass_bool:
                        continue

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = beam_search(sentence, model, SRC, TRG, opt)
                    except Exception as e:
                        continue
                    try:
                        test_acc += simple_em(sentence, e_tgt)
                        test_success += 1
                        test_predictions += sentence + '\n'
                    except Exception as e:
                        test_predictions += '\n'
                        continue


            if test_success == 0:
                test_success = 1
            test_acc = test_acc / test_success
            print('test accuracy:', round(test_acc * 100, 2))
            print()

            if not os.path.exists(opt.output_dir):
                os.makedirs(opt.output_dir)

            with open(opt.output_dir + '/test_generations.txt', 'w', encoding='utf-8') as f:
                f.write(test_predictions)
Example #17
def train_model(model, opt):  # model = NaiveModel, Transformer or Seq2Seq
    val_loss_list = []
    early_stopping_epochs = []
    print("training model...")
    start = time.time()
    opt.start = start
    if opt.checkpoint > 0:
        cptime = time.time()

    criterion = nn.CrossEntropyLoss(
        ignore_index=opt.trg_pad)  # optional (new way)
    for epoch in range(opt.epochs):
        opt.epoch = epoch
        model.train()
        total_loss = 0

        print("   %dm: epoch %d [%s]  %d%%  loss = %s | valid_loss = %s" %\
        ((time.time() - start)//60, epoch + 1, "".join(' '*20), 0, '...', '...'), end='\r')

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        batch_number = get_len(opt.train)
        opt.printevery = batch_number
        print(f"Epoch has {batch_number} batches.")
        if opt.nmt_model_type == 'transformer':
            for i, batch in tqdm(enumerate(
                    opt.train)):  # opt.train = MyIterator

                # [opt.SRC.vocab.itos[i] for i in batch.src[:, 0]] # to query batch words from field
                # ----- OLD WAY ------ #
                # src = batch.src.transpose(0,1)
                # trg = batch.trg.transpose(0,1)
                # trg_input = trg[:, :-1]
                # src_mask, trg_mask = create_masks(src, trg_input, opt)
                # preds = model(src, trg_input, src_mask, trg_mask) # -> [batch_size, sent_len, emb_dim]
                # ys = trg[:, 1:].contiguous().view(-1) # [batch_size * sent_len]
                # opt.optimizer.zero_grad()
                # loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
                # -------------------- #
                # NEW WAY
                src = batch.src.transpose(
                    0, 1)  # do we really need the transpose?
                trg = batch.trg.transpose(
                    0, 1)  # do we really need the transpose?
                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_masks(src, trg_input, opt)
                opt.optimizer.zero_grad()
                output = model(src, trg_input, src_mask,
                               trg_mask)  # -> [batch_size, sent_len, emb_dim]
                output = output.contiguous().view(-1, output.shape[-1])
                trg = trg[:, 1:].contiguous().view(-1)
                loss = criterion(output, trg)
                # else:
                #     src = batch.src
                #     trg = batch.trg
                #     opt.optimizer.zero_grad()
                #     output = model(src, trg)
                #     output = output[1:].view(-1, output.shape[-1])
                #     trg = trg[1:].view(-1)
                #     loss = criterion(output, trg)
                loss.backward()
                opt.optimizer.step()
                total_loss += loss.item()
                if (i + 1) % opt.printevery == 0:
                    p = int(100 * (i + 1) / opt.train_len)
                    avg_train_loss = total_loss / opt.printevery
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_train_loss), end='\r')
                    total_loss = 0
        else:
            total_loss, avg_train_loss = train_rnn(model, opt.train,
                                                   opt.optimizer, criterion, 1,
                                                   opt)

        if opt.SGDR:
            opt.sched.step()

        if opt.checkpoint > 0 and (
            (time.time() - cptime) // 60) // opt.checkpoint >= 1:
            torch.save(model.state_dict(), 'weights/model_weights')
            cptime = time.time()

        avg_valid_loss = evaluate(model, opt.valid, criterion, opt)
        val_loss_list.append(math.exp(avg_valid_loss))
        early_stop_flag = early_stopping_criterion(val_loss_list)
        if early_stop_flag:
            early_stopping_epochs.append(epoch)
            print(
                f"\nModel hasn't improved for 5 epochs. This happened {len(early_stopping_epochs)} times.\n"
            )
            early_stop_flag = early_stopping_criterion(val_loss_list,
                                                       window_size=7)
            if early_stop_flag:
                print(
                    f"\nModel hasn't improved for 7 epochs, terminating train...\n"
                )
                break
        epoch_mins = round((time.time() - start) // 60)
        bar_begin = "".join('#' * (100 // 5))
        bar_end = "".join(' ' * (20 - (100 // 5)))
        print(
            f'{epoch_mins:d}m: Epoch {epoch + 1} [{bar_begin}{bar_end}] 100%')
        print(
            f'Train Loss: {avg_train_loss:.3f} | Train PPL: {math.exp(avg_train_loss):7.3f}'
        )
        print(
            f'Val. Loss: {avg_valid_loss:.3f}  | Val. PPL: {math.exp(avg_valid_loss):7.3f}'
        )
Example #18
def train_model(model, opt, SRC, TRG):
    print("training model...")

    model.train()
    start = time.time()
    mask_prob = opt.mask_prob
    if opt.checkpoint > 0:
        cptime = time.time()

    for epoch in range(opt.epochs):
        model.train()
        total_loss = 0
        avg_loss = 1e5
        print("   %dm: epoch %d [%s]  %d%%  loss = %s" %
              ((time.time() - start) // 60, epoch + 1, "".join(
                  ' ' * 20), 0, '...'),
              end='\r')

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()
            real_trg = batch.trg.transpose(0, 1).cuda()
            bs = src.shape[0]
            add_pad = math.ceil(random.random() * 3)

            masked = False
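            # masking scheme (inferred from the branches below): except for the
            # 'e_snli_o' task, with probability ~mask_prob the decoder input is
            # replaced by an all-<pad> sequence (id 1) starting with <sos> (id 2),
            # so the model must reconstruct the target without seeing it; those
            # branches also append 1-3 extra pad columns to vary the length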
            if opt.task == 'e_snli_o':
                trg = real_trg
                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_masks(src, trg_input, opt)
            elif random.random() > mask_prob:
                trg = torch.cat((real_trg, torch.ones(
                    (bs, add_pad)).type(torch.LongTensor).cuda()),
                                dim=1)
                real_trg = torch.cat((real_trg, torch.ones(
                    (bs, add_pad)).type(torch.LongTensor).cuda()),
                                     dim=1)
                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_masks(src, trg_input, opt)
            else:
                masked = True
                trg = torch.ones_like(real_trg).type(torch.LongTensor).cuda()
                trg[:, 0] = trg[:, 0] * 2
                trg = torch.cat((trg, torch.ones(
                    (bs, add_pad)).type(torch.LongTensor).cuda()),
                                dim=1)
                real_trg = torch.cat((real_trg, torch.ones(
                    (bs, add_pad)).type(torch.LongTensor).cuda()),
                                     dim=1)
                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_hard_masks(src, trg_input, opt)

            if opt.task == 'e_snli_o':
                preds = model(src, src_mask)
            else:
                preds = model(src, trg_input, src_mask, trg_mask)

            # for non-classifier tasks:
            ys = real_trg[:, 1:].contiguous().view(-1)

            opt.optimizer.zero_grad()
            if opt.task == 'e_snli_o':
                loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                       trg.contiguous().view(-1),
                                       ignore_index=opt.trg_pad)
            else:
                if masked:
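                    # sharpen the predictions before the loss: exponentiate the
                    # alpha-scaled softmax probabilities and renormalise, which
                    # peaks the distribution for the masked-target objective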
                    peaked_soft = torch.exp(opt.alpha *
                                            F.softmax(preds, dim=-1))
                    peaked_soft_sum = torch.sum(peaked_soft,
                                                dim=-1).unsqueeze(2)
                    new_preds = torch.div(peaked_soft, peaked_soft_sum)
                    loss = F.cross_entropy(new_preds.view(-1, preds.size(-1)),
                                           ys,
                                           ignore_index=opt.trg_pad)
                else:
                    loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                           ys,
                                           ignore_index=opt.trg_pad)
            loss.backward()
            opt.optimizer.step()

            if opt.wandb:
                if i % opt.log_interval == 0:
                    wandb.log({"loss": loss})

            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()

            if (i + 1) % opt.printevery == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss / opt.printevery
                if opt.floyd is False:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
                          ((time.time() - start) // 60, epoch + 1, "".join('#' * (p // 5)),
                           "".join(' ' * (20 - (p // 5))), p, avg_loss), end='\r')
                else:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
                          ((time.time() - start) // 60, epoch + 1, "".join('#' * (p // 5)),
                           "".join(' ' * (20 - (p // 5))), p, avg_loss))
                total_loss = 0

            if opt.checkpoint > 0 and (
                (time.time() - cptime) // 60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()

        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" % \
              ((time.time() - start) // 60, epoch + 1, "".join('#' * (100 // 5)), "".join(' ' * (20 - (100 // 5))), 100,
               avg_loss, epoch + 1, avg_loss))

        if opt.calculate_val_loss:
            model.eval()
            val_losses = []
            with torch.no_grad():
                for i, batch in enumerate(opt.val):
                    src = batch.src.transpose(0, 1).cuda()
                    trg = batch.trg.transpose(0, 1).cuda()
                    trg_input = trg[:, :-1]
                    src_mask, trg_mask = create_masks(src, trg_input, opt)
                    preds = model(src, trg_input, src_mask, trg_mask)
                    ys = trg[:, 1:].contiguous().view(-1)
                    loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                           ys,
                                           ignore_index=opt.trg_pad)
                    val_losses.append(loss.item())

            print('validation loss:', sum(val_losses) / len(val_losses), '\n')

        if opt.val_forward_pass:
            model.eval()
            val_losses_no_eos = []

            if opt.task == 'toy_task':
                for i, batch in enumerate(opt.val):
                    src = batch.src.transpose(0, 1).cuda()
                    real_trg = batch.trg.transpose(0, 1).cuda()
                    trg = torch.ones_like(real_trg).type(
                        torch.LongTensor).cuda()
                    trg[:, 0] = trg[:, 0] * 2

                    bs = src.shape[0]
                    add_pad = math.ceil(random.random() * 3)
                    trg = torch.cat((trg, torch.ones(
                        (bs, add_pad)).type(torch.LongTensor).cuda()),
                                    dim=1)

                    trg_input = trg[:, :-1]
                    src_mask, trg_mask = create_hard_masks(src, trg_input, opt)
                    preds = model(src, trg_input, src_mask, trg_mask)
                    pred_tokens = torch.argmax(preds, dim=-1)
                    ys = real_trg[:, 1:]

                    for b_ in range(bs):
                        pred_tok = pred_tokens[b_]
                        y = ys[b_]

                        pad_index = ((y == 1).nonzero(as_tuple=True)[0]
                                     )  # 1 = pad in vocab
                        if pad_index.shape[0] == 0:
                            pad_index = y.shape[0]
                        else:
                            pad_index = pad_index[0]
                        if not isinstance(pad_index, int):
                            pad_index = pad_index.item()

                        if torch.equal(pred_tok[:pad_index], y[:pad_index]):
                            val_losses_no_eos.append(1)
                        else:
                            val_losses_no_eos.append(0)
            elif opt.task == 'e_snli_r':
                val_label_accuracy = []
                for i, batch in enumerate(opt.val):
                    src = batch.src.transpose(0, 1).cuda()
                    real_trg1 = batch.trg1.transpose(0, 1).cuda()
                    real_trg2 = batch.trg2.transpose(0, 1).cuda()
                    real_trg3 = batch.trg3.transpose(0, 1).cuda()
                    labels = batch.label.transpose(0, 1).cuda()

                    bs = src.shape[0]
                    max_sl = max([
                        real_trg1.shape[1], real_trg2.shape[1],
                        real_trg3.shape[1]
                    ])

                    trg = torch.ones(
                        (bs, max_sl)).type(torch.LongTensor).cuda()
                    trg[:, 0] = trg[:, 0] * 2

                    add_pad = math.ceil(random.random() * 3)
                    trg = torch.cat((trg, torch.ones(
                        (bs, add_pad)).type(torch.LongTensor).cuda()),
                                    dim=1)

                    trg_input = trg[:, :-1]
                    src_mask, trg_mask = create_hard_masks(src, trg_input, opt)
                    preds = model(src, trg_input, src_mask, trg_mask)
                    pred_tokens = torch.argmax(preds, dim=-1)
                    ys1 = real_trg1[:, 1:]
                    ys2 = real_trg2[:, 1:]
                    ys3 = real_trg3[:, 1:]

                    for b_ in range(bs):
                        pred_tok = pred_tokens[b_]
                        y1 = ys1[b_]
                        y2 = ys2[b_]
                        y3 = ys3[b_]

                        correct = False
                        for y in [y1, y2, y3]:
                            pad_index = ((y == 1).nonzero(as_tuple=True)[0]
                                         )  # 1 = pad in vocab
                            if pad_index.shape[0] == 0:
                                pad_index = y.shape[0]
                            else:
                                pad_index = pad_index[0]
                            if not isinstance(pad_index, int):
                                pad_index = pad_index.item()

                            if torch.equal(pred_tok[:pad_index],
                                           y[:pad_index]):
                                correct = True

                        if correct:
                            val_losses_no_eos.append(1)
                        else:
                            val_losses_no_eos.append(0)

                        rationale = []
                        for t in pred_tok:
                            word = TRG.vocab.itos[t]
                            if word == '<eos>':
                                break  # TODO: take out for differentiable version
                            if word in opt.classifier_SRC.vocab.stoi.keys():
                                rationale.append(
                                    opt.classifier_SRC.vocab.stoi[word])
                            else:
                                rationale.append(
                                    opt.classifier_SRC.vocab.stoi['<unk>'])

                        rationale = torch.Tensor([rationale]).type(
                            torch.LongTensor).cuda()
                        pred_label = classify(rationale, opt.classifier,
                                              opt.classifier_SRC,
                                              opt.classifier_TRG)

                        if pred_label == opt.classifier_TRG.vocab.itos[
                                labels[b_]]:  # latter is true label
                            val_label_accuracy.append(1)
                        else:
                            val_label_accuracy.append(0)
            else:
                raise NotImplementedError(
                    "No validation accuracy support for CoS-E or any other tasks yet."
                )

            if opt.wandb:
                wandb.log({
                    'validation forward accuracy':
                    round(
                        sum(val_losses_no_eos) / len(val_losses_no_eos) * 100,
                        2)
                })
                if opt.task == 'e_snli_r':
                    wandb.log({
                        'validation forward label accuracy':
                        round(
                            sum(val_label_accuracy) / len(val_label_accuracy) *
                            100, 2)
                    })
                    print(
                        'validation forward label accuracy:',
                        round(
                            sum(val_label_accuracy) / len(val_label_accuracy) *
                            100, 2))
            print(
                'validation forward accuracy:',
                round(
                    sum(val_losses_no_eos) / len(val_losses_no_eos) * 100, 2),
                '%')

        if (epoch + 1) % opt.val_check_every_n == 0:
            model.eval()
            val_acc, val_success = 0, 0
            val_data = zip_io_data(opt.data_path + '/val')

            if opt.task == 'toy_task':
                for j, e in enumerate(val_data[:opt.n_val]):
                    e_src, e_tgt = e[0], e[1]

                    if opt.compositional_eval:
                        controller = eval_split_input(e_src)
                        intermediates = []
                        comp_failure = False
                        for controller_input in controller:
                            if len(controller_input) == 1:
                                controller_src = controller_input[0]

                            else:
                                controller_src = ''
                                for src_index in range(
                                        len(controller_input) - 1):
                                    controller_src += intermediates[
                                        controller_input[
                                            src_index]] + ' @@SEP@@ '
                                controller_src += controller_input[-1]
                                controller_src = remove_whitespace(
                                    controller_src)

                            indexed = []
                            sentence = SRC.preprocess(controller_src)
                            for tok in sentence:
                                if SRC.vocab.stoi[tok] != 0:
                                    indexed.append(SRC.vocab.stoi[tok])
                                else:
                                    comp_failure = True
                                    break
                            if comp_failure:
                                break

                            sentence = Variable(torch.LongTensor([indexed]))
                            if opt.device == 0:
                                sentence = sentence.cuda()

                            try:
                                sentence = beam_search(sentence, model, SRC,
                                                       TRG, opt)
                                intermediates.append(sentence)
                            except Exception as e:
                                comp_failure = True

                                break

                        if not comp_failure:
                            try:
                                val_acc += simple_em(intermediates[-1], e_tgt)
                                val_success += 1
                            except Exception as e:
                                continue
                    else:
                        sentence = SRC.preprocess(e_src)
                        indexed = [SRC.vocab.stoi[tok] for tok in sentence]

                        sentence = Variable(torch.LongTensor([indexed]))
                        if opt.device == 0:
                            sentence = sentence.cuda()

                        try:
                            sentence = beam_search(sentence, model, SRC, TRG,
                                                   opt)
                        except Exception as e:
                            continue
                        try:
                            val_acc += simple_em(sentence, e_tgt)
                            val_success += 1
                        except Exception as e:
                            continue
            elif opt.task == 'e_snli_r':
                val_labels = zip_io_data(opt.label_path + '/val')
                beam_label_acc = []
                for j, e in enumerate(val_data[:opt.n_val]):
                    e_src, e_tgt = e[0], e[1]
                    e_label = val_labels[j][1]

                    sentence = SRC.preprocess(e_src)
                    indexed = [SRC.vocab.stoi[tok] for tok in sentence]

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = beam_search(sentence, model, SRC, TRG, opt)
                    except Exception as e:
                        continue

                    candidate_trgs = e_tgt.split(' @@SEP@@ ')
                    correct = False
                    try:
                        for candidate in candidate_trgs:
                            if simple_em(sentence, candidate):
                                correct = True
                    except Exception as e:
                        continue

                    if correct:
                        val_acc += 1
                    val_success += 1

                    rationale = []
                    for word in sentence.split():
                        if word in opt.classifier_SRC.vocab.stoi.keys():
                            rationale.append(
                                opt.classifier_SRC.vocab.stoi[word])
                        else:
                            rationale.append(
                                opt.classifier_SRC.vocab.stoi['<unk>'])

                    rationale = torch.Tensor([rationale
                                              ]).type(torch.LongTensor).cuda()
                    pred_label = classify(rationale, opt.classifier,
                                          opt.classifier_SRC,
                                          opt.classifier_TRG)

                    if pred_label == e_label:  # latter is true label
                        beam_label_acc.append(1)
                    else:
                        beam_label_acc.append(0)

            elif opt.task == 'e_snli_o':
                for j, e in enumerate(val_data[:opt.n_val]):
                    e_src, e_tgt = e[0], e[1]
                    sentence = SRC.preprocess(e_src)
                    indexed = [SRC.vocab.stoi[tok] for tok in sentence]

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = classify(sentence, model, SRC, TRG)
                    except Exception as e:
                        continue

                    if simple_em(sentence, e_tgt):
                        val_acc += 1
                    val_success += 1
            else:
                raise NotImplementedError(
                    "No beam search support for CoS-E or any other tasks yet.")

            if val_success == 0:
                val_success = 1
            val_acc = val_acc / val_success
            print('validation beam accuracy:', round(val_acc * 100, 2))
            if opt.wandb:
                wandb.log(
                    {'validation beam accuracy': round(val_acc * 100, 2)})
                if opt.task == 'e_snli_r':
                    wandb.log({
                        'validation beam label accuracy':
                        round(
                            sum(beam_label_acc) / len(beam_label_acc) * 100, 2)
                    })
                    print(
                        'validation beam label accuracy:',
                        round(
                            sum(beam_label_acc) / len(beam_label_acc) * 100,
                            2))
            opt.scheduler.step(val_acc)
            print('-' * 10, '\n')

        if epoch == opt.epochs - 1 and opt.do_test:
            model.eval()
            test_data = zip_io_data(opt.data_path + '/test')
            test_beam_predictions = ''
            test_fwd_predictions = ''
            beam_acc, fwd_acc = [], []
            beam_label_acc, fwd_label_acc = [], []

            if not os.path.exists(opt.output_dir):
                os.makedirs(opt.output_dir)

            if opt.task == 'e_snli_r':
                test_labels = zip_io_data(opt.label_path + '/test')

            for j, e in enumerate(test_data[:opt.n_test]):
                if (j + 1) % 10000 == 0:
                    print(round(j / len(test_data) * 100, 2),
                          '% complete with testing')
                e_src, e_tgt = e[0], e[1]
                if opt.task == 'e_snli_r':
                    e_label = test_labels[j][1]

                indexed = []
                sentence = SRC.preprocess(e_src)
                pass_bool = False
                for tok in sentence:
                    if SRC.vocab.stoi[tok] != 0:
                        indexed.append(SRC.vocab.stoi[tok])
                    else:
                        pass_bool = True
                        break
                if pass_bool:
                    continue

                sentence = Variable(torch.LongTensor([indexed]))
                if opt.device == 0:
                    sentence = sentence.cuda()

                if opt.val_forward_pass:
                    src = sentence
                    trg = torch.ones(
                        (1, opt.max_strlen)).type(torch.LongTensor).cuda()
                    trg[:, 0] = trg[:, 0] * 2
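                    # Dummy decoder input: every position is filled with <pad> (id 1)
                    # except position 0, which becomes <sos> (id 2, assuming the vocab
                    # conventions used elsewhere in this file), so the forced forward
                    # pass never sees gold target tokens.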

                    trg_input = trg[:, :-1]
                    src_mask, trg_mask = create_hard_masks(src, trg_input, opt)
                    preds = model(src, trg_input, src_mask, trg_mask)
                    pred_tokens = torch.argmax(preds, dim=-1)
                    ys = [TRG.vocab.stoi[tok] for tok in e_tgt.split()
                          ] + [3]  # TODO: remove hardcode of EOS (3)
                    pred_tok = pred_tokens[0].tolist()

                    if pred_tok[:len(ys)] == ys:
                        fwd_acc.append(1)
                    else:
                        fwd_acc.append(0)

                    pred_nl = ' '.join(
                        [TRG.vocab.itos[tok] for tok in pred_tok])
                    if ' <eos>' in pred_nl:
                        pred_nl = pred_nl[:pred_nl.index(
                            ' <eos>'
                        )]  # TODO: take this out for differentiable version
                    if ' .' in pred_nl:
                        pred_nl = pred_nl[:pred_nl.index(
                            ' .'
                        ) + 2]  # TODO: take this out for differentiable version
                    test_fwd_predictions += pred_nl + '\n'

                    if opt.task == 'e_snli_r':
                        rationale = []
                        for word in pred_nl.split():
                            if word in opt.classifier_SRC.vocab.stoi.keys():
                                rationale.append(
                                    opt.classifier_SRC.vocab.stoi[word])
                            else:
                                rationale.append(
                                    opt.classifier_SRC.vocab.stoi['<unk>'])

                        rationale = torch.Tensor([rationale]).type(
                            torch.LongTensor).cuda()
                        pred_label = classify(rationale, opt.classifier,
                                              opt.classifier_SRC,
                                              opt.classifier_TRG)

                        if pred_label == e_label:  # latter is true label
                            fwd_label_acc.append(1)
                        else:
                            fwd_label_acc.append(0)

                try:
                    sentence = beam_search(sentence, model, SRC, TRG, opt)
                except Exception as e:
                    continue
                try:
                    beam_acc.append(simple_em(sentence, e_tgt))
                    test_beam_predictions += sentence + '\n'
                except Exception as e:
                    test_beam_predictions += '\n'
                    continue

                if opt.task == 'e_snli_r':
                    rationale = []
                    for word in sentence.split():
                        if word in opt.classifier_SRC.vocab.stoi.keys():
                            rationale.append(
                                opt.classifier_SRC.vocab.stoi[word])
                        else:
                            rationale.append(
                                opt.classifier_SRC.vocab.stoi['<unk>'])

                    rationale = torch.Tensor([rationale
                                              ]).type(torch.LongTensor).cuda()
                    pred_label = classify(rationale, opt.classifier,
                                          opt.classifier_SRC,
                                          opt.classifier_TRG)

                    if pred_label == e_label:
                        beam_label_acc.append(1)
                    else:
                        beam_label_acc.append(0)

            # beam search logging
            if opt.wandb:
                wandb.log({
                    'test beam accuracy':
                    round(sum(beam_acc) / len(beam_acc) * 100, 2)
                })
            print('test beam accuracy:',
                  round(sum(beam_acc) / len(beam_acc) * 100, 2))
            with open(opt.output_dir + '/test_beam_generations.txt',
                      'w',
                      encoding='utf-8') as f:
                f.write(test_beam_predictions)

            # fwd pass logging
            if opt.val_forward_pass:
                print('test forward accuracy:',
                      round(sum(fwd_acc) / len(fwd_acc) * 100, 2), '%')
                if opt.wandb:
                    wandb.log({
                        'test forward accuracy':
                        round(sum(fwd_acc) / len(fwd_acc) * 100, 2)
                    })
                with open(opt.output_dir + '/test_fwd_generations.txt',
                          'w',
                          encoding='utf-8') as f:
                    f.write(test_fwd_predictions)

            # e_snli_r logging
            if opt.task == 'e_snli_r':
                if opt.wandb:
                    wandb.log({
                        'test beam label accuracy':
                        round(
                            sum(beam_label_acc) / len(beam_label_acc) * 100, 2)
                    })
                print(
                    'test beam label accuracy:',
                    round(sum(beam_label_acc) / len(beam_label_acc) * 100, 2))

                if opt.val_forward_pass:
                    if opt.wandb:
                        wandb.log({
                            'test forward label accuracy':
                            round(
                                sum(fwd_label_acc) / len(fwd_label_acc) * 100,
                                2)
                        })
                    print(
                        'test forward label accuracy:',
                        round(
                            sum(fwd_label_acc) / len(fwd_label_acc) * 100, 2))
Example #19
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-data_path', required=True)
    parser.add_argument('-output_dir', required=True)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-SGDR', action='store_true')
    parser.add_argument('-val_check_every_n', type=int, default=3)
    parser.add_argument('-calculate_val_loss', action='store_true')
    parser.add_argument('-val_forward_pass', action='store_true')
    parser.add_argument('-tensorboard_graph', action='store_true')
    parser.add_argument('-alex', action='store_true')
    parser.add_argument('-compositional_eval', action='store_true')
    parser.add_argument('-wandb', action='store_true')
    parser.add_argument('-n_val', type=int, default=1000)
    parser.add_argument('-n_test', type=int, default=1000)
    parser.add_argument('-do_test', action='store_true')
    parser.add_argument('-epochs', type=int, default=50)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-heads', type=int, default=8)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-mask_prob', type=float, default=0.5)
    parser.add_argument('-alpha', type=float, default=2)
    parser.add_argument('-batchsize', type=int, default=3000)
    parser.add_argument('-printevery', type=int, default=100)
    parser.add_argument('-log_interval', type=int, default=1000)
    parser.add_argument('-lr', type=float, default=0.0001)
    parser.add_argument('-load_weights')
    parser.add_argument('-load_r_to_o')
    parser.add_argument('-label_path')
    parser.add_argument('-create_valset', action='store_true')
    parser.add_argument('-task',
                        type=str,
                        choices=["toy_task", "e_snli_r", "e_snli_o", "cos_e"],
                        default="toy_task")
    parser.add_argument('-max_strlen', type=int, default=512)
    parser.add_argument('-floyd', action='store_true')
    parser.add_argument('-checkpoint', type=int, default=0)

    opt = parser.parse_args()

    wandb_tags = [opt.task]

    opt.device = 0 if opt.no_cuda is False else -1
    if opt.device == 0:
        assert torch.cuda.is_available()
        if opt.alex:
            torch.cuda.set_device(1)

    read_data(opt)

    if opt.task == 'e_snli_r':
        assert opt.label_path is not None
        opt.classifier_SRC, opt.classifier_TRG = create_label_fields(opt)

        with open(opt.load_r_to_o + '/SRC.pkl', 'rb') as f:
            old_SRC = pickle.load(f)

        with open(opt.load_r_to_o + '/TRG.pkl', 'rb') as f:
            old_TRG = pickle.load(f)

        opt.classifier_SRC.vocab = old_SRC.vocab
        opt.classifier_TRG.vocab = old_TRG.vocab
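        # Reuse the vocabularies pickled alongside the pretrained rationale-to-output
        # classifier so its token ids line up with the loaded classifier weights.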

    SRC, TRG = create_fields(opt)
    opt.train, opt.val = create_dataset(opt, SRC, TRG)
    if opt.task == 'e_snli_o':
        model = get_classifier_model(opt, len(SRC.vocab), len(TRG.vocab))
    else:
        if opt.task == 'e_snli_r':
            opt.classifier = load_r_to_o(opt, len(opt.classifier_SRC.vocab),
                                         len(opt.classifier_TRG.vocab))
        model = get_model(opt, len(SRC.vocab), len(TRG.vocab), SRC)

    if opt.wandb:
        config = wandb.config
        config.learning_rate = opt.lr
        config.max_pred_length = opt.max_strlen
        config.mask_prob = opt.mask_prob
        config.batch_size = opt.batchsize
        config.log_interval = opt.log_interval
        group_name = 'masking_probability_p=' + str(
            opt.mask_prob) + '_alpha=' + str(opt.alpha)

        wandb.init(config=config,
                   project='toy-task',
                   entity='c-col',
                   group=group_name,
                   tags=wandb_tags)
        wandb.watch(model)

    if opt.tensorboard_graph:
        writer = SummaryWriter('runs')
        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()
            trg = batch.trg.transpose(0, 1).cuda()
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            writer.add_graph(model, (src, trg_input, src_mask, trg_mask))
            break
        writer.close()

    # beam search parameters
    opt.k = 1
    opt.max_len = opt.max_strlen
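    # With k = 1 above, beam_search degenerates to greedy decoding; max_len simply
    # caps the generated sequence length at max_strlen.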

    opt.optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.lr,
                                     betas=(0.9, 0.98),
                                     eps=1e-9)
    opt.scheduler = ReduceLROnPlateau(opt.optimizer,
                                      factor=0.5,
                                      patience=5,
                                      verbose=True)

    if opt.SGDR:
        opt.sched = CosineWithRestarts(opt.optimizer, T_max=opt.train_len)
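        # Cosine annealing with warm restarts; T_max = opt.train_len presumably gives
        # one annealing cycle per pass over the training batches.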

    if opt.checkpoint > 0:
        print(
            "model weights will be saved every %d minutes and at end of epoch to directory weights/"
            % (opt.checkpoint))

    train_model(model, opt, SRC, TRG)

    if opt.floyd is False:
        promptNextAction(model, opt, SRC, TRG)
Example #20
0
                操作序列 = np.append(操作序列, 抽样np[0, 0])

            else:

                img = np.array(imgA)

                img = torch.from_numpy(img).cuda(device).unsqueeze(0).permute(
                    0, 3, 2, 1) / 255
                _, out = resnet101(img)
                图片张量 = 图片张量[0:18, :]
                操作序列 = 操作序列[0:18]
                操作序列 = np.append(操作序列, 抽样np[0, 0])
                图片张量 = torch.cat((图片张量, out.reshape(1, 6 * 6 * 2048)), 0)
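                # Truncate the stored frame features and action sequence to 18 entries
                # before appending the newest action and frame encoding, keeping a
                # bounded context window for the next prediction.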

            操作张量 = torch.from_numpy(操作序列.astype(np.int64)).cuda(device)
            src_mask, trg_mask = create_masks(操作张量.unsqueeze(0),
                                              操作张量.unsqueeze(0), device)
            输出_实际_A = model(图片张量.unsqueeze(0), 操作张量.unsqueeze(0), trg_mask)

            LI = 操作张量.contiguous().view(-1)
            # LA=输出_实际_A.view(-1, 输出_实际_A.size(-1))
            if 计数 % 50 == 0 and 计数 != 0:

                设备.发送(购买)
                设备.发送(加三技能)
                设备.发送(加二技能)
                设备.发送(加一技能)
                设备.发送(操作查询词典['移动停'])
                print(旧指令, '周期')
                time.sleep(0.02)
                设备.发送(操作查询词典[旧指令])
Example #21
0
def train_model(model, opt):

    print("training model...")
    model.train()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()
    epoch_nums = []
    align_losses = []
    lexical_losses = []
    for epoch in range(opt.epochs):

        total_loss = 0
        if opt.floyd is False:
            print("   %dm: epoch %d [%s]  %d%%  loss = %s" %\
            ((time.time() - start)//60, epoch + 1, "".join(' '*20), 0, '...'), end='\r')

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        for i, batch in enumerate(opt.train):

            src = batch.src.transpose(0, 1).to(device=opt.device)
            trg = batch.trg.transpose(0, 1).to(device=opt.device)
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds, align_loss = model(src, trg_input, src_mask, trg_mask)
            ys = trg[:, 1:].contiguous().view(-1)
            opt.optimizer.zero_grad()
            lexical_loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                           ys,
                                           ignore_index=opt.trg_pad)

            epoch_nums.append(epoch)
            # store plain floats so the loss-history tensor below can be built and
            # the autograd graphs are not kept alive between batches
            align_losses.append(align_loss.item())
            lexical_losses.append(lexical_loss.item())

            # Eq. 5 of paper
            lambda_ = 0.3
            loss = lexical_loss + lambda_ * align_loss
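            # Total loss combines the token-level cross entropy with the model's
            # alignment loss, weighted by a fixed lambda of 0.3 (Eq. 5 of the
            # referenced paper).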

            loss.backward()
            opt.optimizer.step()
            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()

            if (i + 1) % opt.printevery == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss / opt.printevery
                if opt.floyd is False:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
                else:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
                total_loss = 0

            if opt.checkpoint > 0 and (
                (time.time() - cptime) // 60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()


        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" %\
        ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))

    loss_history = torch.FloatTensor(
        [epoch_nums, align_losses, lexical_losses]).t()
    torch.save(loss_history, 'loss_history.pt')
Example #22
0
def train_model(model, opt):
    print("training model...")
    model.train()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()

    for epoch in range(opt.epochs):
        epoch_start_time = time.time()
        total_loss = 0
        '''
    if opt.floyd is False:
      print("   %dm: epoch %d [%s]  %d%%  loss = %s" %\
      ((time.time() - start)//60, epoch + 1, "".join(' '*20), 0, '...'), end='\r')
    '''

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        batch_time_sum = 0
        processed_batches = 0
        for i, batch in enumerate(opt.train):
            batch_start_time = time.time()
            src = batch.src.transpose(0, 1)
            trg = batch.trg.transpose(0, 1)
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            ys = trg[:, 1:].contiguous().view(-1)
            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   ys,
                                   ignore_index=opt.trg_pad)
            loss.backward()
            opt.optimizer.step()
            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()
            '''
      if (i + 1) % opt.printevery == 0:
           p = int(100 * (i + 1) / opt.train_len)
           avg_loss = total_loss/opt.printevery
           if opt.floyd is False:
              print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
              ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
           else:
              print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
              ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
           total_loss = 0
      '''

            if opt.checkpoint > 0 and (
                (time.time() - cptime) // 60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()

            batch_time = time.time() - batch_start_time
            batch_time_sum += batch_time
            processed_batches += 1
            if (i % opt.printevery == 0) and i != 0:
                print("Batch {}/{}, average time: {:.6f}, loss: {:.6f}".format(
                    i, opt.train_len, batch_time_sum / opt.printevery,
                    total_loss / processed_batches))
                batch_time_sum = 0

        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" %\
        ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))
Example #23
0
def main():

    global AI打开
    global 操作列
    加三技能 = '6'
    加二技能 = '5'
    加一技能 = '4'
    购买 = 'f1'
    词数词典路径 = "./json/词_数表.json"
    数_词表路径 = "./json/数_词表.json"
    操作查询路径 = "./json/名称_操作.json"
    操作词典 = {"图片号": "0", "移动操作": "无移动", "动作操作": "无动作"}
    th = threading.Thread(target=start_listen, )
    th.start()  # 启动线程

    if os.path.isfile(词数词典路径) and os.path.isfile(数_词表路径):
        词_数表, 数_词表 = 读出引索(词数词典路径, 数_词表路径)
    with open(词数词典路径, encoding='utf8') as f:
        词数词典 = json.load(f)
    with open(操作查询路径, encoding='utf8') as f:
        操作查询词典 = json.load(f)

    方向表 = ['上移', '下移', '左移', '右移', '左上移', '左下移', '右上移', '右下移']

    设备 = MyMNTDevice(_DEVICE_ID)
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    # mod = torchvision.models.resnet101(pretrained=True).eval().cuda(device).requires_grad_(False)
    # mod = torchvision.models.resnet101(pretrained=True).eval().cpu().requires_grad_(False)
    # mod = torchvision.models.resnet50(pretrained=True).eval().cpu().requires_grad_(False)
    # mod = torchvision.models.resnet34(pretrained=True).eval().cpu().requires_grad_(False)
    mod = torchvision.models.resnet18(
        pretrained=True).eval().cpu().requires_grad_(False)
    resnet101 = myResnet(mod)
    config = TransformerConfig()

    model = get_model(config, 130, 模型名称)

    # model = model.cuda(device).requires_grad_(False)
    model = model.cpu().requires_grad_(False)
    抽样np = 0

    if AI打开:

        图片张量 = torch.Tensor(0)
        操作张量 = torch.Tensor(0)

        # 伪词序列 = torch.from_numpy(np.ones((1, 60)).astype(np.int64)).cuda(device).unsqueeze(0)
        伪词序列 = torch.from_numpy(np.ones(
            (1, 60)).astype(np.int64)).cpu().unsqueeze(0)

        操作序列 = np.ones((1, ))
        操作序列[0] = 128
        计数 = 0
        time_start = time.time()
        旧指令 = '移动停'
        for i in range(1000000):
            # logger.info("++++001+++++")
            if not AI打开:
                break
            try:
                imgA = 取图(窗口名称)
            except:
                AI打开 = False
                print('取图失败')
                break
            # logger.info("+++++002++++")
            计时开始 = time.time()

            # preprocess
            图片张量, 操作序列 = preprocess(图片张量, imgA, resnet101, 操作序列, 抽样np)

            pre_process = time.time()

            logger.info("pre_process : {} ms ".format(pre_process - 计时开始))

            # transform model
            # 操作张量 = torch.from_numpy(操作序列.astype(np.int64)).cuda(device)
            操作张量 = torch.from_numpy(操作序列.astype(np.int64)).cpu()
            src_mask, trg_mask = create_masks(操作张量.unsqueeze(0),
                                              操作张量.unsqueeze(0), device)
            输出_实际_A = model(图片张量.unsqueeze(0), 操作张量.unsqueeze(0), trg_mask)

            # logger.info("+++++003++++")
            LI = 操作张量.contiguous().view(-1)
            # LA=输出_实际_A.view(-1, 输出_实际_A.size(-1))
            if 计数 % 20 == 0 and 计数 != 0:
                print("jineng + zhuangbei ")
                设备.发送(购买)
                设备.发送(加三技能)
                设备.发送(加二技能)
                设备.发送(加一技能)
                设备.发送('移动停')
                logger.warning("{} {}".format(旧指令, '周期'))
                # print(旧指令, '周期')
                # time.sleep(0.02)
                设备.发送(旧指令)
            # logger.info("++++004+++++")
            if 计数 % 1 == 0:
                time_end = time.time()

                输出_实际_A = F.softmax(输出_实际_A, dim=-1)
                输出_实际_A = 输出_实际_A[:, -1, :]
                抽样 = torch.multinomial(输出_实际_A, num_samples=1)
                抽样np = 抽样.cpu().numpy()

                指令 = 数_词表[str(抽样np[0, -1])]
                指令集 = 指令.split('_')
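                # Sampling chain: softmax over the last decoder position, draw one
                # action id with torch.multinomial, map it back to an action name via
                # the id-to-word table, then split that name on '_' into what are
                # presumably its movement/action components.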

                # 操作词典 = {"图片号": "0", "移动操作": "无移动", "动作操作": "无动作"}
                操作词典['图片号'] = str(i)
                方向结果 = 处理方向()
                # logger.info("++++005+++++")
                logger.info("方向结果:{} 操作列:{} 攻击态:{}".format(
                    方向结果, len(操作列), 攻击态))

                # deal with output
                操作列, output_suc = output(方向结果, 操作列, 操作词典, 指令集, 旧指令, 设备, imgA,
                                         i)

                if output_suc == 0:
                    AI打开 = False
                    break

                # logging
                # logger.info("++++008+++++")
                用时1 = 0.22 - (time.time() - 计时开始)
                if 用时1 > 0:
                    logger.info("++++sleep+++++")
                    time.sleep(用时1)
                    logger.info("+++++++++")

                用时 = time_end - time_start
                print("用时{} 第{}张 延时{}".format(用时, i, 用时1), 'A键按下', A键按下,
                      'W键按下', W键按下, 'S键按下', S键按下, 'D键按下', D键按下, '旧指令', 旧指令,
                      'AI打开', AI打开, '操作列', 操作列)

                计数 = 计数 + 1
                # logger.info("++++009+++++")

    记录文件.close()
    time.sleep(1)
    print('AI打开', AI打开)
Example #24
0
def train_model(model, opt, SRC, TRG):
    print("training model...")
    model.train()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()

    for epoch in range(opt.epochs):
        model.train()
        total_loss = 0
        avg_loss = 1e5
        print("   %dm: epoch %d [%s]  %d%%  loss = %s" % ((time.time() - start) // 60, epoch + 1, "".join(' ' * 20), 0, '...'), end='\r')

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()

            assert src.shape[0] == 1
            src_tokens = ' '.join(pred_to_vocab(SRC, src[0]))
            refs, steps = train_split_input(SRC, src_tokens)
            steps = [torch.LongTensor([step]).cuda() for step in steps]

            fake_trg = torch.ones((1, opt.max_strlen)).type(torch.LongTensor).cuda()
            fake_trg[:, 0] = fake_trg[:, 0] * 2
            real_trg = batch.trg.transpose(0, 1).cuda()
            fake_trg_input, real_trg_input = fake_trg[:, :-1], real_trg[:, :-1]
            _, fake_trg_mask = create_hard_masks(src, fake_trg_input, opt)
            _, real_trg_mask = create_masks(src, real_trg_input, opt)

            sep_tensor = model.encoder.embed(torch.LongTensor([model.sep_token]).cuda()).unsqueeze(0)
            decoder_embed = model.decoder.embed.get_weights()
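            # sep_tensor is the embedded @@SEP@@ separator token and decoder_embed the
            # decoder's embedding weights; both are handed to the model so the
            # compositional forward pass can presumably stitch intermediate steps
            # together (exact usage depends on the model implementation).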

            try:
                preds = model(sep_tensor, decoder_embed, steps, refs, fake_trg_input, real_trg_input, fake_trg_mask, real_trg_mask)
            except RuntimeError:
                continue
            ys = real_trg[:, 1:].contiguous().view(-1)

            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
            try:
                loss.backward()
            except RuntimeError:
                continue
            opt.optimizer.step()

            # print('success on step length', len(steps), '; token length', src.shape[1])
            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()

            if (i + 1) % opt.printevery == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss / opt.printevery
                if opt.floyd is False:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
                          ((time.time() - start) // 60, epoch + 1, "".join('#' * (p // 5)),
                           "".join(' ' * (20 - (p // 5))), p, avg_loss), end='\r')
                else:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
                          ((time.time() - start) // 60, epoch + 1, "".join('#' * (p // 5)),
                           "".join(' ' * (20 - (p // 5))), p, avg_loss))
                total_loss = 0

            if opt.checkpoint > 0 and ((time.time() - cptime) // 60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()

        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" % \
              ((time.time() - start) // 60, epoch + 1, "".join('#' * (100 // 5)), "".join(' ' * (20 - (100 // 5))), 100,
               avg_loss, epoch + 1, avg_loss))

        if opt.calculate_val_loss:
            model.eval()
            val_losses = []
            for i, batch in enumerate(opt.val):
                src = batch.src.transpose(0, 1).cuda()
                trg = batch.trg.transpose(0, 1).cuda()
                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_masks(src, trg_input, opt)
                preds = model(src, trg_input, src_mask, trg_mask)
                ys = trg[:, 1:].contiguous().view(-1)
                opt.optimizer.zero_grad()
                loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
                val_losses.append(loss.item())

            print('validation loss:', sum(val_losses) / len(val_losses), '\n')

        if opt.val_forward_pass:
            model.eval()
            val_losses = []
            val_losses_no_eos = []
            val_eos_dict = {k: [0, 0] for k in range(1, 601)}
            val_pad_dict = {k: [0, 0] for k in range(1, 601)}
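            # val_eos_dict / val_pad_dict map a sequence length (1-600) to
            # [correct count, total count], so accuracy can later be broken down by
            # target length.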
            for i, batch in enumerate(opt.val):
                src = batch.src.transpose(0, 1).cuda()
                # trg = batch.trg.transpose(0, 1).cuda()
                # TODO: confirm swap below
                real_trg = batch.trg.transpose(0, 1).cuda()
                trg = torch.ones_like(real_trg).type(torch.LongTensor).cuda()
                trg[:, 0] = trg[:, 0] * 2

                bs = src.shape[0]
                add_pad = math.ceil(random.random() * 3)
                trg = torch.cat((trg, torch.ones((bs, add_pad)).type(torch.LongTensor).cuda()), dim=1)
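                # Append 1-3 extra pad tokens at random, presumably so the forward
                # pass cannot rely on a fixed dummy-target length.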

                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_hard_masks(src, trg_input, opt)
                preds = model(src, trg_input, src_mask, trg_mask)
                pred_tokens = torch.argmax(preds, dim=-1)
                # ys = trg[:, 1:] TODO: was swapped with real
                ys = real_trg[:, 1:]

                for b_ in range(bs):
                    pred_tok = pred_tokens[b_]
                    y = ys[b_]
                    sl = y.shape[0]

                    eos_index = ((y == 3).nonzero(as_tuple=True)[0])[0]  # 3 = eos in vocab
                    if type(eos_index) != int:
                        eos_index = eos_index.item()

                    if torch.equal(pred_tok[:eos_index], y[:eos_index]):
                        val_losses.append(1)
                        if eos_index in val_eos_dict.keys():  # add to seq length counter
                            val_eos_dict[eos_index][0] += 1
                    else:
                        val_losses.append(0)

                    if eos_index in val_eos_dict.keys():
                        val_eos_dict[eos_index][1] += 1

                    pad_index = ((y == 1).nonzero(as_tuple=True)[0])  # 1 = pad in vocab
                    if pad_index.shape[0] == 0:
                        pad_index = y.shape[0]
                    else:
                        pad_index = pad_index[0]
                    if type(pad_index) != int:
                        pad_index = pad_index.item()

                    if torch.equal(pred_tok[:pad_index], y[:pad_index]):
                        val_losses_no_eos.append(1)
                        if pad_index in val_pad_dict.keys():  # add to seq length counter
                            val_pad_dict[pad_index][0] += 1
                    else:
                        val_losses_no_eos.append(0)

                    if pad_index in val_pad_dict.keys():
                        val_pad_dict[pad_index][1] += 1

            print('forward pass validation accuracy - no eos:',
                  round(sum(val_losses) / len(val_losses) * 100, 2), '%')
            print('forward pass validation accuracy - no pad:',
                  round(sum(val_losses_no_eos) / len(val_losses_no_eos) * 100, 2), '%')

        if (epoch + 1) % opt.val_check_every_n == 0:
            model.eval()
            val_acc, val_success = 0, 0
            val_data = zip_io_data(opt.data_path + '/val')
            for j, e in enumerate(val_data[:opt.n_val]):
                e_src, e_tgt = e[0], e[1]

                if opt.compositional_eval:
                    controller = eval_split_input(e_src)
                    intermediates = []
                    comp_failure = False
                    for controller_input in controller:
                        if len(controller_input) == 1:
                            controller_src = controller_input[0]

                        else:
                            controller_src = ''
                            for src_index in range(len(controller_input) - 1):
                                controller_src += intermediates[controller_input[src_index]] + ' @@SEP@@ '
                            controller_src += controller_input[-1]
                            controller_src = remove_whitespace(controller_src)

                        indexed = []
                        sentence = SRC.preprocess(controller_src)
                        for tok in sentence:
                            if SRC.vocab.stoi[tok] != 0:
                                indexed.append(SRC.vocab.stoi[tok])
                            else:
                                comp_failure = True
                                break
                        if comp_failure:
                            break

                        sentence = Variable(torch.LongTensor([indexed]))
                        if opt.device == 0:
                            sentence = sentence.cuda()

                        try:
                            sentence = beam_search(sentence, model, SRC, TRG, opt)
                            intermediates.append(sentence)
                        except Exception as e:
                            comp_failure = True

                            break

                    if not comp_failure:
                        try:
                            val_acc += simple_em(intermediates[-1], e_tgt)
                            val_success += 1
                        except Exception as e:
                            continue
                else:
                    sentence = SRC.preprocess(e_src)
                    indexed = [SRC.vocab.stoi[tok] for tok in sentence]

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = beam_search(sentence, model, SRC, TRG, opt)
                    except Exception as e:
                        continue
                    try:
                        val_acc += simple_em(sentence, e_tgt)
                        val_success += 1
                    except Exception as e:
                        continue

            if val_success == 0:
                val_success = 1
            val_acc = val_acc / val_success
            print('epoch', epoch, '- val accuracy:', round(val_acc * 100, 2))
            print()
            opt.scheduler.step(val_acc)

        if epoch == opt.epochs - 1 and opt.do_test:
            model.eval()
            test_data = zip_io_data(opt.data_path + '/test')
            test_predictions = ''
            test_acc, test_success = 0, 0
            for j, e in enumerate(test_data[:opt.n_test]):
                if (j + 1) % 10000 == 0:
                    print(round(j / len(test_data) * 100, 2), '% complete with testing')
                e_src, e_tgt = e[0], e[1]

                if opt.compositional_eval:
                    controller = eval_split_input(e_src)
                    intermediates = []
                    comp_failure = False
                    for controller_input in controller:
                        if len(controller_input) == 1:
                            controller_src = controller_input[0]

                        else:
                            controller_src = ''
                            for src_index in range(len(controller_input) - 1):
                                controller_src += intermediates[controller_input[src_index]] + ' @@SEP@@ '
                            controller_src += controller_input[-1]
                            controller_src = remove_whitespace(controller_src)

                        indexed = []
                        sentence = SRC.preprocess(controller_src)
                        for tok in sentence:
                            if SRC.vocab.stoi[tok] != 0:
                                indexed.append(SRC.vocab.stoi[tok])
                            else:
                                comp_failure = True
                                break
                        if comp_failure:
                            break

                        sentence = Variable(torch.LongTensor([indexed]))
                        if opt.device == 0:
                            sentence = sentence.cuda()

                        try:
                            sentence = beam_search(sentence, model, SRC, TRG, opt)
                            intermediates.append(sentence)
                        except Exception as e:
                            comp_failure = True
                            break

                    if not comp_failure:
                        try:
                            test_acc += simple_em(sentence, e_tgt)
                            test_success += 1
                            test_predictions += sentence + '\n'
                        except Exception as e:
                            test_predictions += '\n'
                            continue
                    else:
                        test_predictions += '\n'
                else:
                    indexed = []
                    sentence = SRC.preprocess(e_src)
                    pass_bool = False
                    for tok in sentence:
                        if SRC.vocab.stoi[tok] != 0:
                            indexed.append(SRC.vocab.stoi[tok])
                        else:
                            pass_bool = True
                            break
                    if pass_bool:
                        continue

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = beam_search(sentence, model, SRC, TRG, opt)
                    except Exception as e:
                        continue
                    try:
                        test_acc += simple_em(sentence, e_tgt)
                        test_success += 1
                        test_predictions += sentence + '\n'
                    except Exception as e:
                        test_predictions += '\n'
                        continue

            if test_success == 0:
                test_success = 1
            test_acc = test_acc / test_success
            print('test accuracy:', round(test_acc * 100, 2))
            print()

            if not os.path.exists(opt.output_dir):
                os.makedirs(opt.output_dir)

            with open(opt.output_dir + '/test_generations.txt', 'w', encoding='utf-8') as f:
                f.write(test_predictions)
Example #25
0

if __name__ == '__main__':
    # encoder = CNN_Encoder(256)
    # decoder = RNN_Decoder(256, 512, 5000)
    out_encoder_feature = tf.random.uniform((64, 64, 256))
    num_layers = 4
    d_model = 512
    num_heads = 8
    dff = 1024
    vocab_size = 5000
    maximum_position_encoding = 200
    # tar_inp = tar[:, :-1]
    # tar_real = tar[:, 1:]
    decoder = ImageCaptioningTransformer(num_layers, d_model, num_heads, dff,
                                         maximum_position_encoding, vocab_size)

    temp_target = tf.random.uniform((64, 1), maxval=vocab_size, dtype=tf.int32)
    x = tf.expand_dims([0] * temp_target.shape[0], 1)
    # print(temp_target.shape,x.shape)
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        x, temp_target)

    output = decoder(out_encoder_feature,
                     temp_target,
                     training=True,
                     enc_padding_mask=enc_padding_mask,
                     look_ahead_mask=combined_mask,
                     dec_padding_mask=dec_padding_mask)
    print(output.shape)
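    # Expected shape: (64, 1, 5000), i.e. one vocabulary distribution per batch item
    # and target position (assuming the decoder returns per-position logits).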