import os

import torch
from torch.utils.tensorboard import SummaryWriter


def train(args):
    if not os.path.exists(args.model_path):
        os.mkdir(args.model_path)
    writer = SummaryWriter("log")
    torch.cuda.set_device(args.device_id)
    model = CrossModal(vocab_size=args.vocab_size,
                       pretrain_path=args.pretrain_path).cuda()
    # model = torch.nn.DataParallel(model).cuda()
    criterion = RankLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    step = 0
    for epoch in range(args.epochs):
        train_reader = DataReader(args.vocab_path, args.train_data_path,
                                  args.image_path, args.vocab_size,
                                  args.batch_size, is_shuffle=True)
        print("train reader loaded.")
        for train_batch in train_reader.batch_generator():
            # Batch layout: (query token ids, positive images, negative images).
            query = torch.from_numpy(train_batch[0]).cuda()
            pos = torch.stack(train_batch[1], 0).cuda()
            neg = torch.stack(train_batch[2], 0).cuda()

            optimizer.zero_grad()
            # The model returns the query-positive and query-negative scores.
            left, right = model(query, pos, neg)
            loss = criterion(left, right)  # inputs are on GPU; no extra .cuda() needed
            loss.backward()
            optimizer.step()

            if step == 0:
                writer.add_graph(model, (query, pos, neg))
            if step % 100 == 0:
                writer.add_scalar('Train/Loss', loss.item(), step)
            if step % args.eval_interval == 0:
                print('Epoch [{}/{}], Step [{}] Loss: {:.4f}'.format(
                    epoch + 1, args.epochs, step, loss.item()), flush=True)
            if step % args.save_interval == 0:
                # Save the model checkpoint.
                torch.save(model.state_dict(), '%s/model.ckpt' % args.model_path)
            step += 1
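# RankLoss is not defined in this file. Below is a minimal sketch of a pairwise
# margin ranking loss that matches the criterion(left, right) call above,
# assuming `left`/`right` are the query-positive and query-negative similarity
# scores; the margin value and reduction are assumptions, not the repository's
# actual implementation.
import torch
import torch.nn as nn


class RankLoss(nn.Module):
    """Assumed pairwise hinge loss: push the positive score above the
    negative score by at least `margin`."""

    def __init__(self, margin=1.0):
        super(RankLoss, self).__init__()
        self.margin = margin

    def forward(self, pos_score, neg_score):
        # max(0, margin - pos + neg), averaged over the batch.
        return torch.clamp(self.margin - pos_score + neg_score, min=0).mean()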
import os

import tensorflow as tf


def train(args):
    if not os.path.exists(args.model_path):
        os.mkdir(args.model_path)
    tf.reset_default_graph()
    model = CrossModel(vocab_size=args.vocab_size)

    # optimizer
    train_step = tf.contrib.opt.LazyAdamOptimizer(
        learning_rate=args.learning_rate).minimize(model.loss)
    saver = tf.train.Saver()
    loss_summary = tf.summary.scalar("train_loss", model.loss)
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    with tf.Session() as sess:
        sess.run(init)
        # Restore the pretrained weights from the checkpoint.
        init_variables_from_checkpoint(args.pretrain_path)
        _writer = tf.summary.FileWriter(args.logdir, sess.graph)

        # Initialize the text embedding table from a pretrained file.
        embedding = load_embedding(args.emb_path, args.vocab_size, 256)
        _ = sess.run(model.embedding_init,
                     feed_dict={model.embedding_in: embedding})
        print("pretrained embedding loaded.")

        # summary
        summary_op = tf.summary.merge([loss_summary])

        step = 0
        for epoch in range(args.epochs):
            train_reader = DataReader(args.vocab_path, args.train_data_path,
                                      args.image_data_path, args.vocab_size,
                                      args.batch_size, is_shuffle=True)
            print("train reader loaded.")
            for train_batch in train_reader.batch_generator():
                query, pos, neg = train_batch
                _, _loss, _summary = sess.run(
                    [train_step, model.loss, summary_op],
                    feed_dict={
                        model.text: query,
                        model.img_pos: pos,
                        model.img_neg: neg
                    })
                _writer.add_summary(_summary, step)
                step += 1

                # Periodically evaluate on the test split.
                if step % args.eval_interval == 0:
                    print("Epochs: {}, Step: {}, Train Loss: {:.4}".format(
                        epoch, step, _loss))
                    sum_loss = 0.0
                    iters = 0
                    summary = tf.Summary()
                    test_reader = DataReader(args.vocab_path, args.test_data_path,
                                             args.image_data_path, args.vocab_size,
                                             args.batch_size)
                    for test_batch in test_reader.batch_generator():
                        query, pos, neg = test_batch
                        _loss = sess.run(model.loss,
                                         feed_dict={
                                             model.text: query,
                                             model.img_pos: pos,
                                             model.img_neg: neg
                                         })
                        sum_loss += _loss
                        iters += 1
                    avg_loss = sum_loss / iters
                    summary.value.add(tag="test_loss", simple_value=avg_loss)
                    _writer.add_summary(summary, step)
                    print("Epochs: {}, Step: {}, Test Loss: {:.4}".format(
                        epoch, step, avg_loss))

                if step % args.save_interval == 0:
                    save_path = saver.save(
                        sess, "{}/model.ckpt".format(args.model_path),
                        global_step=step)
                    print("Model saved to: {}".format(save_path))
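# load_embedding and init_variables_from_checkpoint are defined elsewhere in
# the repository. A minimal sketch of load_embedding is given below, assuming
# a whitespace-separated text file with one vector per vocabulary id; the file
# format and the random fallback for missing ids are assumptions, not the
# repository's actual implementation.
import numpy as np


def load_embedding(emb_path, vocab_size, emb_dim):
    """Assumed loader: returns a [vocab_size, emb_dim] float32 matrix,
    randomly initialized for ids missing from the file."""
    embedding = np.random.uniform(-0.05, 0.05,
                                  (vocab_size, emb_dim)).astype(np.float32)
    with open(emb_path) as f:
        for idx, line in enumerate(f):
            if idx >= vocab_size:
                break
            values = line.rstrip().split()
            if len(values) == emb_dim:
                embedding[idx] = np.asarray(values, dtype=np.float32)
    return embedding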
import os

import numpy as np
import tensorflow as tf


def train(args):
    if not os.path.exists(args.model_path):
        os.mkdir(args.model_path)
    tf.reset_default_graph()
    model = TextClassification(vocab_size=args.vocab_size,
                               encoder_type=args.encoder_type,
                               max_seq_len=args.max_seq_len)

    # optimizer
    train_step = tf.contrib.opt.LazyAdamOptimizer(
        learning_rate=args.learning_rate).minimize(model.loss)
    saver = tf.train.Saver()
    loss_summary = tf.summary.scalar("train_loss", model.loss)
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    with tf.Session() as sess:
        sess.run(init)
        _writer = tf.summary.FileWriter(args.logdir, sess.graph)

        # summary
        summary_op = tf.summary.merge([loss_summary])

        step = 0
        for epoch in range(args.epochs):
            train_reader = DataReader(args.vocab_path, args.train_data_path,
                                      args.vocab_size, args.batch_size,
                                      args.max_seq_len)
            for train_batch in train_reader.batch_generator():
                text, label = train_batch
                _, _loss, _summary, _logits = sess.run(
                    [train_step, model.loss, summary_op, model.logits],
                    feed_dict={model.label_in: label, model.text_in: text})
                _writer.add_summary(_summary, step)
                step += 1

                # Periodically evaluate on the test split. Accuracy is computed
                # in numpy rather than with tf.metrics.accuracy: creating new
                # metric ops inside the loop grows the graph on every call.
                if step % args.eval_interval == 0:
                    summary = tf.Summary()
                    _acc = np.mean(np.argmax(label, 1) == np.argmax(_logits, 1))
                    summary.value.add(tag="train_accuracy", simple_value=_acc)
                    print("Epochs: {}, Step: {}, Train Loss: {}, Acc: {}".format(
                        epoch, step, _loss, _acc))

                    test_reader = DataReader(args.vocab_path, args.test_data_path,
                                             args.vocab_size, args.batch_size,
                                             args.max_seq_len)
                    sum_loss = 0.0
                    sum_acc = 0.0
                    iters = 0
                    for test_batch in test_reader.batch_generator():
                        text, label = test_batch
                        _loss, _logits = sess.run(
                            [model.loss, model.logits],
                            feed_dict={model.label_in: label, model.text_in: text})
                        sum_acc += np.mean(
                            np.argmax(label, 1) == np.argmax(_logits, 1))
                        sum_loss += _loss
                        iters += 1
                    avg_loss = sum_loss / iters
                    avg_acc = sum_acc / iters
                    summary.value.add(tag="test_accuracy", simple_value=avg_acc)
                    summary.value.add(tag="test_loss", simple_value=avg_loss)
                    _writer.add_summary(summary, step)
                    print("Epochs: {}, Step: {}, Test Loss: {}, Acc: {}".format(
                        epoch, step, avg_loss, avg_acc))

                if step % args.save_interval == 0:
                    save_path = saver.save(
                        sess, "{}/birnn.lm.ckpt".format(args.model_path),
                        global_step=step)
                    print("Model saved to: {}".format(save_path))
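# None of the scripts show how `args` is built. Below is a minimal sketch of a
# command-line entry point for the text-classification trainer; every flag
# name and default is an assumption inferred from the attributes train()
# reads, not the repository's actual interface.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default="./checkpoints")
    parser.add_argument("--logdir", default="./log")
    parser.add_argument("--vocab_path", required=True)
    parser.add_argument("--train_data_path", required=True)
    parser.add_argument("--test_data_path", required=True)
    parser.add_argument("--vocab_size", type=int, default=50000)
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--encoder_type", default="birnn")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--eval_interval", type=int, default=1000)
    parser.add_argument("--save_interval", type=int, default=1000)
    train(parser.parse_args())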