Example #1
import pickle

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

# Transformer, Optim, Dataload, collate_fn, get_args, train, and eval are
# assumed to come from the surrounding project.


def main():
    args = get_args()
    # Vocabulary sizes are read from the pickled dict files built at preprocessing time.
    src_vocab_size = len(
        pickle.load(open(args.data_bin + '/dict.' + args.src_lang, 'rb')))
    tgt_vocab_size = len(
        pickle.load(open(args.data_bin + '/dict.' + args.tgt_lang, 'rb')))
    device = 'cuda'
    model = Transformer(src_vocab_size=src_vocab_size,
                        tgt_vocab_size=tgt_vocab_size,
                        encoder_layer_num=args.encoder_layer_num,
                        decoder_layer_num=args.decoder_layer_num,
                        hidden_size=args.hidden_size,
                        feedback_size=args.feedback,
                        num_head=args.num_head,
                        dropout=args.dropout,
                        device=device)
    optim = Optim(Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9),
                  warmup_step=4000,
                  d_model=args.hidden_size)
    train_loader = DataLoader(Dataload(args.data_bin + '/' + 'train',
                                       args.src_lang, args.tgt_lang),
                              batch_size=args.batch_size,
                              collate_fn=collate_fn,
                              shuffle=True)
    # optim = Adam(model.parameters(), lr=5e-6)
    test_loader = DataLoader(Dataload(args.data_bin + '/' + 'test',
                                      args.src_lang, args.tgt_lang),
                             batch_size=args.batch_size,
                             collate_fn=collate_fn)
    valid_loader = DataLoader(Dataload(args.data_bin + '/' + 'valid',
                                       args.src_lang, args.tgt_lang),
                              batch_size=args.batch_size,
                              collate_fn=collate_fn)
    best_loss = 1e4
    model = model.to(device)
    # model.load_state_dict(torch.load('best_model.pkl'))
    for i in range(args.epoch):
        train(i, model, data_loader=train_loader, optim=optim, device=device)
        with torch.no_grad():
            # eval here is the project's validation routine, not the builtin eval().
            best_loss = eval(i, model, valid_loader, best_loss, device)
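The Optim wrapper above pairs Adam (betas=(0.9, 0.98), eps=1e-9) with a 4000-step warmup, which matches the learning-rate schedule from "Attention Is All You Need". A minimal sketch of what such a wrapper typically implements, assuming the Noam schedule; the class and method names here are illustrative, not the project's actual Optim:

class NoamOptim:
    """Inverse-square-root schedule with linear warmup (a sketch)."""

    def __init__(self, optimizer, warmup_step, d_model):
        self.optimizer = optimizer
        self.warmup_step = warmup_step
        self.d_model = d_model
        self.step_num = 0

    def rate(self):
        # lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
        return self.d_model ** -0.5 * min(
            self.step_num ** -0.5,
            self.step_num * self.warmup_step ** -1.5)

    def step(self):
        self.step_num += 1
        for group in self.optimizer.param_groups:
            group['lr'] = self.rate()
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()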
Example #2
import torch
import torch.nn as nn

# full_model, train, init, print_e, and Dataload are project modules.


def train_phoneme(num_layers, lr, batch_size, hidden,
                  numepoch, dropout, inputfeeding, cuda, maxlen, soft=True):
    dataset = Dataload(maxlen=maxlen)
    criterion = nn.NLLLoss(ignore_index=0, reduction='sum')
    model = full_model.make_model(dataset.src_num, dataset.trg_num, emb_size=32,
                                  hidden_size=hidden, num_layers=num_layers, dropout=dropout,
                                  inputfeeding=inputfeeding, soft=soft)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    eval_data = list(dataset.data_gen(batch_size=1, num_batches=100, eval=True))
    dev_perplexities = []

    if init.USE_CUDA and cuda:
        model.cuda()
    min_perplexity = 100
    minp_iter = 0
    max_accuracy = 0
    maxa_iter = 0
    for epoch in range(numepoch):
        print("Epoch %d" % epoch)

        model.train()
        # presumably eval=False for the training batches (the held-out
        # eval_data above is the one generated with eval=True)
        data = dataset.data_gen(batch_size=batch_size, num_batches=100,
                                eval=False)
        train.run_epoch(data, model,
                        train.SimpleLossCompute(model.generator, criterion, optim))
        model.eval()
        with torch.no_grad():
            perplexity, accuracy = train.run_epoch(
                eval_data, model,
                train.SimpleLossCompute(model.generator, criterion, None))
            if perplexity < min_perplexity:
                min_perplexity = perplexity
                minp_iter = epoch
            if accuracy > max_accuracy:
                max_accuracy = accuracy
                maxa_iter = epoch
            print("Evaluation perplexity: %f" % perplexity)
            print("Evaluation accuracy: %f" % accuracy)
            dev_perplexities.append(perplexity)
            print_e.print_examples(eval_data, dataset, model, n=2, max_len=maxlen)
    print("min perplexity: %f at %d iterations" % (min_perplexity, minp_iter))
    print("max accuracy: %f at %d iterations" % (max_accuracy, maxa_iter))
    return dev_perplexities
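Because the criterion is nn.NLLLoss(ignore_index=0, reduction='sum'), the perplexity reported by run_epoch is presumably exp(summed NLL / token count). run_epoch and SimpleLossCompute are project code, so this is an assumption about their behavior, shown as a small sketch:

import math

def perplexity(total_nll, total_tokens):
    # With reduction='sum' and ignore_index=0, padding tokens contribute
    # neither to the summed loss nor to the token count.
    return math.exp(total_nll / total_tokens)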
Example #3
import os

import torch

# KFC_MER_Model, Dataload, model_dir, transform_test, use_cuda, and
# all_sub_list are defined earlier in the original script.

total = 0
y_true = []
y_pre = []
for leave_out_sub in all_sub_list:
    print('==> Leaving out ' + leave_out_sub)

    net = KFC_MER_Model()
    model_path = os.path.join(model_dir, leave_out_sub, 'Resnest18_2blocks_OFOS_map_49.pth')
    checkpoint = torch.load(model_path)
    net.load_state_dict(checkpoint['net'])
    net.eval()

    if use_cuda:
        net.cuda()

    test_set = Dataload(split='Testing', transform=transform_test, leave_out=leave_out_sub)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=4, shuffle=False, num_workers=1)

    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs1 = inputs[:, :3, :, :]
        inputs2 = inputs[:, 3:6, :, :]
        seg_map = inputs[:, 6:, :, :]

        if use_cuda:
            inputs1, inputs2, seg_map, targets = (inputs1.cuda(), inputs2.cuda(),
                                                  seg_map.cuda(), targets.cuda())
        # torch.autograd.Variable is a no-op in modern PyTorch; instead, run the
        # whole test-time forward pass under no_grad so no gradients are tracked.
        with torch.no_grad():
            outputs = net(inputs1, seg_map, inputs2, seg_map)

        _, predicted = torch.max(outputs, 1)
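The loop accumulates per-batch predictions for a leave-one-subject-out evaluation; the snippet is cut off before y_true and y_pre are extended (typically y_true.extend(targets.tolist()) and y_pre.extend(predicted.tolist()) per batch). A sketch of how such pooled predictions are commonly summarized, assuming scikit-learn is available; the metric choice is illustrative, not taken from the original script:

from sklearn.metrics import accuracy_score, f1_score

def summarize(y_true, y_pre):
    acc = accuracy_score(y_true, y_pre)             # pooled accuracy over all folds
    uf1 = f1_score(y_true, y_pre, average='macro')  # unweighted (macro) F1
    return acc, uf1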
Example #4
import pickle

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

device = 'cuda'  # assumed; args and add_ are defined earlier in the original script
src_dict = pickle.load(
    open(add_(args.bin_path) + 'dict.' + args.src_lang, 'rb'))
trg_dict = pickle.load(
    open(add_(args.bin_path) + 'dict.' + args.tgt_lang, 'rb'))
model = Transformer(src_vocab_size=len(src_dict.keys()),
                    tgt_vocab_size=len(trg_dict.keys()),
                    encoder_layer_num=6,
                    decoder_layer_num=6,
                    hidden_size=512,
                    feedback_size=2048,
                    num_head=8,
                    dropout=0.1,
                    device=device)
model = model.to(device)
model.load_state_dict(torch.load(args.model_path))
dataload = DataLoader(Dataload(add_(args.bin_path) + 'test',
                               src=args.src_lang,
                               trg=args.tgt_lang),
                      batch_size=32,
                      collate_fn=collate_fn)
real = []
predict = []
pbtr = tqdm(total=len(dataload))
with torch.no_grad():
    model.eval()
    for src, trg in dataload:
        src = src.to(device)
        predicts = beamsearch(model, src, 1, 100, device=device)

        for i in range(len(predicts)):
            while 0 in predicts[i]:
                predicts[i].remove(0)
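The while/remove loop strips the padding id 0 from each decoded sentence in quadratic time; an equivalent one-liner (assuming, as the loop does, that 0 is the pad id):

predicts = [[tok for tok in sent if tok != 0] for sent in predicts]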
Example #5
    def train(self):

        with tf.name_scope('loss'):
            loss = tf.nn.ctc_loss(self._label, self._net_output, self._seq_len)
            loss = tf.reduce_mean(loss)
            tf.summary.scalar("loss", loss)

        with tf.name_scope('optimizer'):
            train_op = tf.train.AdamOptimizer(
                self._learning_rate).minimize(loss)

        with tf.name_scope('accuracy'):
            accuracy = 1 - tf.reduce_mean(
                tf.edit_distance(tf.cast(self._decoded[0], tf.int32),
                                 self._label))
            accuracy_broad = tf.summary.scalar("accuracy", accuracy)

        data = Dataload(self.batch_size,
                        './data/dataset_label.txt',
                        img_height=self.input_height,
                        img_width=self.input_width)

        # save the model
        saver = tf.train.Saver()

        # tensorboard
        merged = tf.summary.merge_all()

        with tf.Session() as sess:
            if self._pre_train:
                saver.restore(sess, self._model_save_path)
                print('load model from:', self._model_save_path)
            else:
                sess.run(tf.global_variables_initializer())

            train_writer = tf.summary.FileWriter("./tensorboard_logs/",
                                                 sess.graph)

            epoch = data.epoch
            for step in range(self._start_step + 1, self._max_iterators):
                batch_data, batch_label = data.get_train_batch()

                feed_dict = {
                    self._inputs: batch_data,
                    self._label: batch_label,
                    self._seq_len: [self._max_char_count] * self.batch_size
                }

                # run the train op and the summaries in a single pass instead
                # of evaluating the graph twice per step
                summ, _ = sess.run([merged, train_op], feed_dict=feed_dict)
                train_writer.add_summary(summ, global_step=step)

                if step % 20 == 0:
                    train_loss = sess.run(loss, feed_dict=feed_dict)
                    self.train_logger.info('step:%d, total loss: %.6f' %
                                           (step, train_loss))
                    self.train_logger.info('compute accuracy...')
                    train_accuracy = sess.run(accuracy, feed_dict=feed_dict)
                    val_data, val_label = data.get_val_batch(self.batch_size)
                    val_accuracy = sess.run(
                        accuracy,
                        feed_dict={
                            self._inputs: val_data,
                            self._label: val_label,
                            self._seq_len:
                            [self._max_char_count] * self.batch_size
                        })

                    self.train_logger.info('epoch:%d, train accuracy: %.6f' %
                                           (epoch, train_accuracy))
                    self.train_logger.info('epoch:%d, val accuracy: %.6f' %
                                           (epoch, val_accuracy))
                    # used to check that the network's output is correct
                    # if train_accuracy>0.9:
                    #     print('label:', batch_label)
                    #     print('predict:', sess.run(self.dense_decoded, feed_dict=feed_dict))

                # if step%10 == 0:
                #     train_accuracy = sess.run(accuracy, feed_dict=feed_dict)
                #     self.train_logger.info('step:%d, train accuracy: %6f' % (epoch, train_accuracy))

                if step % 100 == 0:

                    self.train_logger.info('saving model...')
                    # step already runs from self._start_step + 1, so it is the
                    # absolute step count; adding _start_step again would
                    # double-count on resume
                    with open('./model/train_step.txt', 'w') as f:
                        f.write(str(step))
                    save_path = saver.save(sess, self._model_save_path)
                    self.train_logger.info('model saved at %s' % save_path)

                if epoch != data.epoch:
                    epoch = data.epoch
                    self.train_logger.info('compute accuracy...')
                    train_accuracy = sess.run(accuracy, feed_dict=feed_dict)
                    self.train_logger.info('epoch:%d, accuracy: %.6f' %
                                           (epoch, train_accuracy))
                    summ = sess.run(accuracy_broad, feed_dict=feed_dict)
                    train_writer.add_summary(summ, global_step=step)

            train_writer.close()
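The accuracy node above is one minus the mean edit distance between the CTC decoding and the label, where tf.edit_distance normalizes by the label length by default. A plain-Python sketch of the same metric, illustrative rather than the TensorFlow implementation:

def edit_distance(a, b):
    # Classic Levenshtein dynamic program over two token sequences.
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1,
                                     prev + (ca != cb))
    return dp[-1]

def sequence_accuracy(decoded, labels):
    # Normalize by label length, matching tf.edit_distance(normalize=True);
    # max(..., 1) guards against empty labels.
    dists = [edit_distance(d, l) / max(len(l), 1)
             for d, l in zip(decoded, labels)]
    return 1 - sum(dists) / len(dists)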