Example #1
import os

import torch
# PyTorch >= 1.1 ships SummaryWriter; older projects used tensorboardX instead
from torch.utils.tensorboard import SummaryWriter

# CrossModal, RankLoss, and DataReader are project-local modules (not shown)
def train(args):
    if not os.path.exists(args.model_path):
        os.mkdir(args.model_path)

    writer = SummaryWriter("log")
    torch.cuda.set_device(args.device_id)

    model = CrossModal(vocab_size=args.vocab_size, 
            pretrain_path=args.pretrain_path).cuda()
    #model = torch.nn.DataParallel(model).cuda()
    criterion = RankLoss()
    optimizer = torch.optim.Adam(model.parameters(), 
            lr=args.learning_rate)

    step = 0
    for epoch in range(args.epochs):
        train_reader = DataReader(args.vocab_path, args.train_data_path, args.image_path, 
                args.vocab_size, args.batch_size, is_shuffle=True)
        print("train reader load succ......")
        for train_batch in train_reader.batch_generator():
            query = torch.from_numpy(train_batch[0]).cuda()
            pos = torch.stack(train_batch[1], 0).cuda()
            neg = torch.stack(train_batch[2], 0).cuda()

            optimizer.zero_grad()
        
            left, right = model(query, pos, neg)
            loss = criterion(left, right)  # inputs are on the GPU, so the loss tensor already is too

            loss.backward()
            optimizer.step()
            if step == 0:
                writer.add_graph(model, (query, pos, neg))

            if step % 100 == 0:
                writer.add_scalar('Train/Loss', loss.item(), step)    

            if step % args.eval_interval == 0:
                print('Epoch [{}/{}], Step [{}] Loss: {:.4f}'.format(epoch + 1, 
                            args.epochs, step, loss.item()), flush=True)

            if step % args.save_interval == 0:
                # Save the model checkpoint
                torch.save(model.state_dict(), '%s/model.ckpt' % args.model_path)
            step += 1
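
A minimal driver for the example above, wiring train() to argparse, might look like the sketch below. Only the flag names are taken from the attribute accesses in the code; every default value here is an assumption.

import argparse

def parse_args():
    # hypothetical driver: one flag per args.<name> access in train()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default="./ckpt")
    parser.add_argument("--pretrain_path", default="./pretrain.pt")
    parser.add_argument("--vocab_path", default="./vocab.txt")
    parser.add_argument("--train_data_path", default="./train.txt")
    parser.add_argument("--image_path", default="./images")
    parser.add_argument("--device_id", type=int, default=0)
    parser.add_argument("--vocab_size", type=int, default=50000)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--eval_interval", type=int, default=1000)
    parser.add_argument("--save_interval", type=int, default=5000)
    return parser.parse_args()

if __name__ == "__main__":
    train(parse_args())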
Example #2
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.Session)

# CrossModel, DataReader, init_variables_from_checkpoint, and load_embedding
# are project-local helpers (not shown in this snippet)
def train(args):
    if not os.path.exists(args.model_path):
        os.mkdir(args.model_path)
    #tf.reset_default_graph()
    model = CrossModel(vocab_size=args.vocab_size)
    # optimizer
    train_step = tf.contrib.opt.LazyAdamOptimizer(
        learning_rate=args.learning_rate).minimize(model.loss)
    saver = tf.train.Saver()
    loss_summary = tf.summary.scalar("train_loss", model.loss)
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init)
        # alternative restore path: slim.assign_from_checkpoint_fn(args.pretrain_path, ...)
        init_variables_from_checkpoint(args.pretrain_path)

        _writer = tf.summary.FileWriter(args.logdir, sess.graph)
        # init embedding
        embedding = load_embedding(args.emb_path, args.vocab_size, 256)
        _ = sess.run(model.embedding_init,
                     feed_dict={model.embedding_in: embedding})
        print("loading pretrain emb succ.")

        # summary
        summary_op = tf.summary.merge([loss_summary])
        step = 0
        for epoch in range(args.epochs):
            train_reader = DataReader(args.vocab_path,
                                      args.train_data_path,
                                      args.image_data_path,
                                      args.vocab_size,
                                      args.batch_size,
                                      is_shuffle=True)
            print("train reader load succ.")
            for train_batch in train_reader.batch_generator():
                query, pos, neg = train_batch

                _, _loss, _summary = sess.run(
                    [train_step, model.loss, summary_op],
                    feed_dict={
                        model.text: query,
                        model.img_pos: pos,
                        model.img_neg: neg
                    })
                _writer.add_summary(_summary, step)
                step += 1

                # periodic evaluation on the test set; accumulators and the
                # summary proto only need to exist when we actually evaluate
                if step % args.eval_interval == 0:
                    print("Epochs: {}, Step: {}, Train Loss: {:.4}".format(
                        epoch, step, _loss))

                    sum_loss = 0.0
                    iters = 0
                    summary = tf.Summary()

                    test_reader = DataReader(args.vocab_path,
                                             args.test_data_path,
                                             args.image_data_path,
                                             args.vocab_size, args.batch_size)
                    for test_batch in test_reader.batch_generator():
                        query, pos, neg = test_batch
                        _loss = sess.run(model.loss,
                                         feed_dict={
                                             model.text: query,
                                             model.img_pos: pos,
                                             model.img_neg: neg
                                         })
                        sum_loss += _loss
                        iters += 1
                    avg_loss = sum_loss / iters
                    summary.value.add(tag="test_loss", simple_value=avg_loss)
                    _writer.add_summary(summary, step)
                    print("Epochs: {}, Step: {}, Test Loss: {:.4}".format(
                        epoch, step, sum_loss / iters))
                if step % args.save_interval == 0:
                    # saver.save returns the actual path, including the
                    # global-step suffix it appends
                    save_path = saver.save(sess,
                                           "{}/model.ckpt".format(
                                               args.model_path),
                                           global_step=step)
                    print("Model saved to: {}".format(save_path))
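
The helpers init_variables_from_checkpoint and load_embedding are called above but not defined in this snippet. Below is a minimal sketch of what they might do; the variable-matching logic and the embedding file format are assumptions, not the original implementation.

import numpy as np
import tensorflow as tf

def init_variables_from_checkpoint(pretrain_path):
    # hypothetical: restore only the variables whose names also appear in the
    # checkpoint, using the session made default by the enclosing with-block
    reader = tf.train.NewCheckpointReader(pretrain_path)
    ckpt_vars = set(reader.get_variable_to_shape_map().keys())
    to_restore = [v for v in tf.global_variables() if v.op.name in ckpt_vars]
    tf.train.Saver(var_list=to_restore).restore(
        tf.get_default_session(), pretrain_path)

def load_embedding(emb_path, vocab_size, dim):
    # assumed file format: "word_id v1 v2 ... v<dim>" per line; rows missing
    # from the file keep a small random initialization
    emb = np.random.uniform(-0.1, 0.1, (vocab_size, dim)).astype(np.float32)
    with open(emb_path) as f:
        for line in f:
            parts = line.split()
            idx = int(parts[0])
            if idx < vocab_size:
                emb[idx] = np.asarray(parts[1:1 + dim], dtype=np.float32)
    return emb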
Example #3
import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.Session)

# TextClassification and DataReader are project-local modules (not shown)
def train(args):
    if not os.path.exists(args.model_path):
        os.mkdir(args.model_path)
    tf.reset_default_graph()
    model = TextClassification(vocab_size=args.vocab_size, 
            encoder_type=args.encoder_type, max_seq_len=args.max_seq_len)
    # optimizer
    train_step = tf.contrib.opt.LazyAdamOptimizer(learning_rate=args.learning_rate).minimize(model.loss)
    saver = tf.train.Saver()
    loss_summary = tf.summary.scalar("train_loss", model.loss)
    init = tf.group(tf.global_variables_initializer(), 
            tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init)
        # feeding embedding
        _writer = tf.summary.FileWriter(args.logdir, sess.graph)

        # summary
        summary_op = tf.summary.merge([loss_summary])
        step = 0
        for epoch in range(args.epochs):
            train_reader = DataReader(args.vocab_path, args.train_data_path, 
                    args.vocab_size, args.batch_size, args.max_seq_len)
            for train_batch in train_reader.batch_generator():
                text, label = train_batch
                _, _loss, _summary, _logits = sess.run([train_step, model.loss, summary_op, model.logits],
                        feed_dict={model.label_in: label, model.text_in: text})
                _writer.add_summary(_summary, step)
                step += 1


                # periodic evaluation
                if step % args.eval_interval == 0:
                    summary = tf.Summary()
                    # compute batch accuracy in numpy: building tf.metrics.accuracy
                    # ops here would add new nodes to the graph on every evaluation
                    _acc = np.mean(np.argmax(_logits, 1) == np.argmax(label, 1))
                    summary.value.add(tag="train_accuracy", simple_value=_acc)
                    print("Epochs: {}, Step: {}, Train Loss: {}, Acc: {}".format(epoch, step, _loss, _acc))

                    test_reader = DataReader(args.vocab_path, args.test_data_path, 
                            args.vocab_size, args.batch_size, args.max_seq_len)
                    sum_loss = 0.0
                    sum_acc = 0.0
                    iters = 0
                    for test_batch in test_reader.batch_generator():
                        text, label = test_batch
                        _loss, _logits = sess.run([model.loss, model.logits],
                                feed_dict={model.label_in: label, model.text_in: text})
                        # same fix as above: no new metric ops per batch
                        _acc = np.mean(np.argmax(_logits, 1) == np.argmax(label, 1))
                        sum_acc += _acc
                        sum_loss += _loss
                        iters += 1
                    avg_loss = sum_loss / iters
                    avg_acc = sum_acc / iters
                    summary.value.add(tag="test_accuracy", simple_value=avg_acc)
                    summary.value.add(tag="test_loss", simple_value=avg_loss)
                    _writer.add_summary(summary, step)
                    print("Epochs: {}, Step: {}, Test Loss: {}, Acc: {}".format(epoch, step, avg_loss, avg_acc))
                if step % args.save_interval == 0:
                    save_path = saver.save(sess, "{}/birnn.lm.ckpt".format(args.model_path), global_step=step)
                    print("Model saved to: {}".format(save_path))