def _caption_from_indices(gen_idx, vocab_inv):
    """Turn a sequence of generated word indices into a caption string.

    Indices that are missing from ``vocab_inv`` are rendered as '<PAD>'.
    The word sequence is cut right after the first '<EOS>' (unless that
    would leave a single token), and the '<BOS> ' / ' <EOS>' markers are
    then removed from the space-joined text.
    """
    words = np.array([vocab_inv.get(idx, '<PAD>') for idx in gen_idx])
    # np.argmax returns 0 both when '<EOS>' is the first token and when it
    # is absent, so a cut point of 1 is ambiguous and the sequence is kept.
    cut = np.argmax(words == '<EOS>') + 1
    if cut > 1:
        words = words[:cut]
    caption = ' '.join(words)
    return caption.replace('<BOS> ', '').replace(' <EOS>', '')


def test(args):
    """Generate one caption per test video and dump them as JSON.

    Reads ``config.pkl`` and ``vocab.pkl`` from ``args.init_from``, builds a
    ``Video_Caption_Generator`` in inference mode from the saved config,
    restores the latest checkpoint found in ``args.init_from`` and runs the
    model once for every id listed (one per line) in ``args.testing_file``.

    Args:
        args: object exposing the attributes ``init_from`` (directory with
            config.pkl, vocab.pkl and the checkpoint), ``testing_file``
            (text file of video ids), ``testing_path`` (passed to
            ``get_video_feat``) and ``result_file`` (output JSON path).

    Side effects:
        Prints progress to stdout and writes a JSON list of
        ``{"caption": ..., "id": ...}`` objects to ``args.result_file``.

    Raises:
        AssertionError: if ``config.pkl`` is missing from ``args.init_from``.
    """
    config_path = os.path.join(args.init_from, "config.pkl")
    assert os.path.isfile(config_path), \
        "config.pkl file does not exist in path %s" % args.init_from

    # NOTE(security): unpickling executes arbitrary code; only point
    # init_from at directories whose contents are trusted.
    with open(config_path, 'rb') as f:
        saved_args = cPickle.load(f)

    # Configs saved by older versions may lack newer options; fill them in
    # with neutral defaults so the model constructor can rely on them.
    saved = vars(saved_args)
    if "attention" in saved:
        print("attention: %d" % saved["attention"])
    else:
        saved["attention"] = 0
    if "schedule_sampling" in saved:
        # %s rather than %d: the value is a float (default 0.0) and %d would
        # truncate e.g. 0.3 to 0 in the log line.
        print("schedule_sampling: %s" % saved["schedule_sampling"])
    else:
        saved["schedule_sampling"] = 0.0

    with open(os.path.join(args.init_from, 'vocab.pkl'), 'rb') as f:
        vocab = cPickle.load(f)
    vocab_inv = {v: k for k, v in vocab.items()}

    # One video id per line; blank lines (e.g. a trailing newline at the end
    # of the file) are skipped instead of being treated as an empty id.
    with open(args.testing_file, 'r') as f:
        test_feat_id = [line.rstrip('\r\n') for line in f]
    test_feat_id = [feat_id for feat_id in test_feat_id if feat_id]

    model = Video_Caption_Generator(saved_args, n_vocab=len(vocab), infer=True)

    result = []
    with tf.Session() as sess:
        # Initialise and restore ONCE. Doing this per video is redundant and
        # constructing a new Saver in a loop keeps adding ops to the graph.
        tf.global_variables_initializer().run()
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored %s" % ckpt.model_checkpoint_path)
        else:
            print("Warning: no checkpoint found in %s; "
                  "captions come from randomly initialized weights"
                  % args.init_from)

        for this_test_feat_id in test_feat_id:
            # The second argument to get_video_feat must be an np.array,
            # hence the single-element array wrapper.
            current_feat, current_feat_mask = get_video_feat(
                args.testing_path, np.array([this_test_feat_id]))
            this_gen_idx, _ = sess.run(
                [model.gen_caption_idx, model.pred_probs],
                feed_dict={
                    model.video: current_feat,
                    model.video_mask: current_feat_mask
                })

            this_caption = _caption_from_indices(this_gen_idx, vocab_inv)
            print('Id: %s, caption: %s' % (this_test_feat_id, this_caption))
            result.append({'caption': this_caption, 'id': this_test_feat_id})

    with open(args.result_file, 'w') as fout:
        json.dump(result, fout)
def train(args):
    """Train the video caption model, optionally resuming from a checkpoint.

    When ``args.init_from`` is set, the vocabulary and saved config are read
    from that directory, the saved config is checked against ``args`` for the
    size-related options, and the latest checkpoint there is restored before
    training. Otherwise ``args`` and the freshly built vocabulary are pickled
    into ``args.save_dir``. In both cases the model itself is constructed
    from ``args`` (the saved config is only used for the compatibility
    check).

    Args:
        args: object usable with ``vars()`` exposing ``init_from``,
            ``save_dir``, ``train_label_json``, ``test_label_json``,
            ``train_video_feat_path``, ``gpu_mem``, ``n_epoch``,
            ``save_every`` and, when resuming, ``dim_image``, ``dim_hidden``,
            ``n_lstm_step``, ``n_video_step`` and ``n_caption_step``.

    Side effects:
        Prints progress to stdout, writes per-batch losses to
        ``log/loss.txt`` and saves checkpoints to
        ``<args.save_dir>/model.ckpt-<epoch>``.

    Raises:
        AssertionError: if resuming and ``config.pkl`` or a checkpoint is
            missing from ``args.init_from``, or the saved config disagrees
            with ``args`` on one of the checked options.
    """
    ckpt = None
    if args.init_from is not None:
        config_path = os.path.join(args.init_from, "config.pkl")
        assert os.path.isfile(config_path), \
            "config.pkl file does not exist in path %s" % args.init_from

        # get_checkpoint_state returns None when the directory holds no
        # checkpoint; fail here with a clear message instead of with an
        # AttributeError at restore time, after the data has been loaded.
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt and ckpt.model_checkpoint_path, \
            "no checkpoint found in path %s" % args.init_from

        # NOTE(security): unpickling executes arbitrary code; only resume
        # from directories whose contents are trusted.
        with open(os.path.join(args.init_from, 'vocab.pkl'), 'rb') as f:
            vocab = cPickle.load(f)

        # The vocabulary comes from the saved run, so the one rebuilt by
        # data_preprocess is discarded.
        _, _, train_feat_id, train_caption, _, _ = data_preprocess(
            args.train_label_json, args.test_label_json)

        with open(config_path, 'rb') as f:
            saved_args = cPickle.load(f)
        saved = vars(saved_args)
        need_be_same = [
            "dim_image", "dim_hidden", "n_lstm_step", "n_video_step",
            "n_caption_step"
        ]
        for checkme in need_be_same:
            assert saved[checkme] == vars(args)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        # Configs saved by older versions may lack this option.
        if "schedule_sampling" in saved:
            # %s rather than %d: the value is a float and %d would truncate
            # e.g. 0.3 to 0 in the log line.
            print("schedule_sampling: %s" % saved["schedule_sampling"])
        else:
            saved["schedule_sampling"] = 0.0
    else:
        with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
            cPickle.dump(args, f)
        vocab, _, train_feat_id, train_caption, _, _ = data_preprocess(
            args.train_label_json, args.test_label_json)
        with open(os.path.join(args.save_dir, 'vocab.pkl'), 'wb') as f:
            cPickle.dump(vocab, f)

    model = Video_Caption_Generator(args, n_vocab=len(vocab), infer=False)

    # The loss log lives in ./log; create it so open() below cannot fail on
    # a fresh working directory.
    if not os.path.isdir('log'):
        os.makedirs('log')

    n_train = len(train_feat_id)
    batch_size = model.batch_size
    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')

    # Cap the fraction of GPU memory this process may allocate.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        tf.global_variables_initializer().run()
        print("Initialized")
        saver = tf.train.Saver(tf.global_variables())
        if ckpt is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Context manager so the log is flushed/closed even if training
        # raises part-way through.
        with open('log/loss.txt', 'w') as loss_fd:
            for epoch in range(0, args.n_epoch):
                if model.schedule_sampling > 0.0:
                    # [pseudo] schedule-sampling probability grows with the
                    # epoch, capped at 1.0. 50.0 forces true division so the
                    # factor also grows under Python 2.
                    model.schedule_sampling = np.min(
                        [model.schedule_sampling * (1.0 + epoch / 50.0), 1.0])

                # Visit the training set in a new random order each epoch.
                index = np.random.permutation(n_train)
                epoch_train_feat_id = train_feat_id[index]
                epoch_train_caption = train_caption[index]

                # Full batches only: the trailing partial batch is dropped,
                # but the last FULL batch is kept (the stop value is
                # inclusive of a start at n_train - batch_size).
                for start in range(0, n_train - batch_size + 1, batch_size):
                    end = start + batch_size
                    start_time = time.time()

                    batch_feat_id = epoch_train_feat_id[start:end]
                    batch_caption = epoch_train_caption[start:end]

                    current_feat, current_feat_mask = get_video_feat(
                        args.train_video_feat_path, batch_feat_id)
                    # Captions are padded to n_caption_step + 1 tokens.
                    current_caption, current_caption_mask = get_padding_caption(
                        vocab, batch_caption, maxlen=model.n_caption_step + 1)

                    # One optimisation step; also fetch the loss for logging.
                    _, loss_val = sess.run(
                        [model.train_op, model.tf_loss],
                        feed_dict={
                            model.video: current_feat,
                            model.video_mask: current_feat_mask,
                            model.caption: current_caption,
                            model.caption_mask: current_caption_mask
                        })

                    print('idx: ', start, " Epoch: ", epoch, " loss: ",
                          loss_val, ' Elapsed time: ',
                          str((time.time() - start_time)))
                    loss_fd.write('epoch ' + str(epoch) + ' loss ' +
                                  str(loss_val) + '\n')

                # Save periodically, and always after the final epoch so the
                # last stretch of training is not lost.
                is_last_epoch = epoch == args.n_epoch - 1
                if np.mod(epoch, args.save_every) == 0 or is_last_epoch:
                    saver.save(sess, checkpoint_path, global_step=epoch)
                    print("Epoch ", epoch,
                          "model saved to {}".format(checkpoint_path))