Ejemplo n.º 1
0
def train(sub_dir, logging, model_save_dir, result_save_dir):
    """Train an SSAD_SCDM model over pre-batched h5 files and save checkpoints.

    On first use (when ``options['word_fts_path']`` does not exist) the word
    vocabulary and embedding file are built from the caption metadata;
    otherwise they are loaded from the paths in ``options``. The training
    graph is then built, a checkpoint is optionally restored (module-level
    ``finetune`` flag), and ``options['max_epochs']`` epochs are run. Each h5
    file is fed to the graph as one training batch.

    Args:
        sub_dir: Not used by this function; kept for call compatibility.
        logging: Logger-like object; its ``info`` method is used here and it
            is passed through to ``preProBuildWordVocab``.
        model_save_dir: Directory checkpoints are restored from (when
            fine-tuning) and saved to, under the prefix ``model``.
        result_save_dir: Not used by this function; kept for call
            compatibility.
    """
    # ---- vocabulary and word embeddings ----
    if not os.path.exists(options['word_fts_path']):
        meta_data, train_data, test_data = get_video_data_jukin(
            options['video_data_path_train'], options['video_data_path_test'])
        captions = meta_data['Description'].values
        for c in string.punctuation:
            # A list comprehension applies each replacement immediately. A lazy
            # map (Python 3) would let every lambda see only the final `c`.
            captions = [x.replace(c, '') for x in captions]
        wordtoix, ixtoword, _ = preProBuildWordVocab(
            logging, captions, word_count_threshold=1)
        np.save(options['ixtoword_path'], ixtoword)
        np.save(options['wordtoix_path'], wordtoix)
        get_word_embedding(options['word_embedding_path'],
                           options['wordtoix_path'], options['ixtoword_path'],
                           options['word_fts_path'])
    else:
        wordtoix = np.load(options['wordtoix_path']).tolist()
        ixtoword = np.load(options['ixtoword_path']).tolist()
        train_data = get_video_data_HL(
            options['video_data_path_train'])  # get h5 file list
    # Load by path so numpy opens the file in binary mode and closes it
    # (np.load(open(...)) left the handle open).
    word_emb_init = np.array(
        np.load(options['word_fts_path']).tolist(), np.float32)

    if finetune:
        start_epoch = 150
        restore_path = model_save_dir + '/model-' + str(start_epoch - 1)
        # A resumed run continues the epoch numbering of the restored model.
        epoch_offset = start_epoch
    else:
        epoch_offset = 0

    # ---- graph ----
    model = SSAD_SCDM(options, word_emb_init)
    inputs, outputs = model.build_train()
    t_loss = outputs['loss_all']
    t_loss_regular = outputs['reg_loss']
    # (name used in log messages, tensor) for every loss averaged per epoch.
    tracked_losses = [
        ('loss', t_loss),
        ('loss_ssad', outputs['loss_ssad']),
        ('positive_loss_all', outputs['positive_loss_all']),
        ('hard_negative_loss_all', outputs['hard_negative_loss_all']),
        ('easy_negative_loss_all', outputs['easy_negative_loss_all']),
        ('smooth_center_loss_all', outputs['smooth_center_loss_all']),
        ('smooth_width_loss_all', outputs['smooth_width_loss_all']),
    ]
    loss_names = [name for name, _ in tracked_losses]

    t_feature_segment = inputs['feature_segment']
    t_sentence_index_placeholder = inputs['sentence_index_placeholder']
    t_sentence_w_len = inputs['sentence_w_len']
    t_gt_overlap = inputs['gt_overlap']

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    sess = tf.InteractiveSession(config=config)
    optimizer = optimizer_factory[options['optimizer']](
        **options['opt_arg'][options['optimizer']])
    if options['clip']:
        gvs = optimizer.compute_gradients(t_loss)
        # compute_gradients may return a None gradient for variables the loss
        # does not depend on; tf.clip_by_value cannot take None, so skip them.
        capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var)
                      for grad, var in gvs if grad is not None]
        train_op = optimizer.apply_gradients(capped_gvs)
    else:
        train_op = optimizer.minimize(t_loss)

    with tf.device("/cpu:0"):
        saver = tf.train.Saver(max_to_keep=200)
    tf.initialize_all_variables().run()

    if finetune:
        with tf.device("/cpu:0"):
            saver.restore(sess, restore_path)

    # ---- training ----
    tStart_total = time.time()
    last_step = None  # checkpoint step of the most recently finished epoch
    for epoch in range(options['max_epochs']):
        step = epoch + epoch_offset

        index = np.arange(len(train_data))
        np.random.shuffle(index)
        train_data = train_data[index]

        tStart_epoch = time.time()
        n_files = len(train_data)
        # One slot per h5 file for each tracked loss.
        epoch_losses = {name: np.zeros(n_files) for name in loss_names}

        for file_idx in range(n_files):
            logging.info("current_batch_file_idx = {:d}".format(file_idx))
            logging.info(train_data[file_idx])

            # Read what this batch needs, then close the h5 file.
            with h5py.File(train_data[file_idx], 'r') as current_batch:
                raw_captions = current_batch['sentence']
                current_captions = np.array(
                    [raw_captions[ind] for ind in range(options['batch_size'])])
                current_video_feats = np.array(
                    current_batch['video_source_fts'])
                current_anchor_input = np.array(current_batch['anchor_input'])

            # Remove punctuation and surrounding whitespace; an empty caption
            # becomes '.'.
            for ind in range(options['batch_size']):
                caption = current_captions[ind]
                for c in string.punctuation:
                    caption = caption.replace(c, '')
                caption = caption.strip()
                current_captions[ind] = caption if caption else '.'

            # Words missing from the vocabulary are dropped.
            current_caption_ind = [
                [wordtoix[word] for word in cap.lower().split(' ')
                 if word in wordtoix]
                for cap in current_captions]
            current_caption_matrix = sequence.pad_sequences(
                current_caption_ind,
                padding='post',
                maxlen=options['max_sen_len'] - 1)
            # Append one zero column: rows become max_sen_len wide and always
            # end with a 0 index.
            current_caption_matrix = np.hstack([
                current_caption_matrix,
                np.zeros([len(current_caption_matrix), 1])
            ]).astype(int)
            # Sentence length = number of non-zero word indices per row.
            current_caption_length = (current_caption_matrix != 0).sum(axis=1)

            results = sess.run(
                [train_op, t_loss_regular] +
                [tensor for _, tensor in tracked_losses],
                feed_dict={
                    t_feature_segment: current_video_feats,
                    t_sentence_index_placeholder: current_caption_matrix,
                    t_sentence_w_len: current_caption_length,
                    t_gt_overlap: current_anchor_input
                })
            loss_regular = results[1]
            batch_losses = dict(zip(loss_names, results[2:]))
            for name, value in batch_losses.items():
                epoch_losses[name][file_idx] = value

            logging.info(
                "loss = {:f} loss_ssad = {:f} loss_regular = {:f} positive_loss_all = {:f} hard_negative_loss_all = {:f} easy_negative_loss_all = {:f} smooth_center_loss_all = {:f} smooth_width_loss_all = {:f}"
                .format(batch_losses['loss'], batch_losses['loss_ssad'],
                        loss_regular, batch_losses['positive_loss_all'],
                        batch_losses['hard_negative_loss_all'],
                        batch_losses['easy_negative_loss_all'],
                        batch_losses['smooth_center_loss_all'],
                        batch_losses['smooth_width_loss_all']))

        logging.info("Epoch: {:d} done.".format(step))
        tStop_epoch = time.time()
        logging.info('Epoch Time Cost: {:f} s'.format(
            round(tStop_epoch - tStart_epoch, 2)))
        for name in loss_names:
            logging.info('Current Epoch Mean {:s} {:f}'.format(
                name, np.mean(epoch_losses[name])))

        # ---- checkpoint (every epoch from epoch 50 on) ----
        if epoch >= 50:
            logging.info('Epoch {:d} is done. Saving the model ...'.format(step))
            with tf.device("/cpu:0"):
                saver.save(sess,
                           os.path.join(model_save_dir, 'model'),
                           global_step=step)
        last_step = step

    if last_step is None:
        # max_epochs was 0: there is no trained epoch to checkpoint.
        logging.info("No epochs were run; skipping the final save.")
    else:
        logging.info("Finally, saving the model ...")
        with tf.device("/cpu:0"):
            saver.save(sess,
                       os.path.join(model_save_dir, 'model'),
                       global_step=last_step)

    tStop_total = time.time()
    logging.info("Total Time Cost: {:f} s".format(
        round(tStop_total - tStart_total, 2)))
Ejemplo n.º 2
0
def train(logging, model_save_dir, result_save_dir):
    """Train an SSAD_SCDM model with facial-map supervision and save checkpoints.

    On first use (when ``options['word_fts_path']`` does not exist) the word
    vocabulary and embedding file are built from the caption metadata;
    otherwise they are loaded from the paths in ``options``. The training
    graph is then built, a checkpoint is optionally restored (module-level
    ``finetune`` flag), and ``options['max_epochs']`` epochs are run. Each h5
    file is fed to the graph as one training batch; its video features and
    facial map come from ``generate_batch_video_fts`` /
    ``generate_batch_facial_map`` applied to the file's ``video_name`` data.
    Per-epoch mean losses are written to ``epoch_loss.npy`` under
    ``options['result_save_dir'] + task``.

    Args:
        logging: Logger-like object; its ``info`` method is used here and it
            is passed through to ``preProBuildWordVocab``.
        model_save_dir: Directory checkpoints are restored from (when
            fine-tuning) and saved to, under the prefix ``model``.
        result_save_dir: Not used by this function (see the ``epoch_loss.npy``
            location above); kept for call compatibility.
    """
    # str.translate table that deletes every character of string.punctuation.
    no_punct = str.maketrans('', '', string.punctuation)

    # ---- vocabulary and word embeddings ----
    if not os.path.exists(options['word_fts_path']):
        meta_data, train_data, dev_data, test_data = get_video_data_jukin(
            options['video_data_path_train'], options['video_data_path_dev'],
            options['video_data_path_test'])
        # Strip punctuation eagerly so the vocabulary matches the
        # punctuation-free captions used during training. A chain of lazy
        # `map(lambda x: x.replace(c, ''), ...)` calls would make every lambda
        # see only the last `c`, leaving almost all punctuation in place.
        captions = [
            x.translate(no_punct) for x in meta_data['Description'].values
        ]
        wordtoix, ixtoword, _ = preProBuildWordVocab(
            logging, captions, word_count_threshold=1)
        np.save(options['ixtoword_path'], ixtoword)
        np.save(options['wordtoix_path'], wordtoix)
        get_word_embedding(options['word_embedding_path'],
                           options['wordtoix_path'], options['ixtoword_path'],
                           options['word_fts_path'])
    else:
        # NOTE: allow_pickle=True unpickles these files; load only vocab files
        # written by this pipeline, never untrusted input.
        wordtoix = np.load(options['wordtoix_path'],
                           allow_pickle=True).tolist()
        ixtoword = np.load(options['ixtoword_path'],
                           allow_pickle=True).tolist()
        train_data = get_video_data_HL(
            options['video_data_path_train'])  # get h5 file list
    word_emb_init = np.array(
        np.load(options['word_fts_path'], encoding='bytes',
                allow_pickle=True).tolist(), np.float32)

    if finetune:
        start_epoch = PRETRAINED_EPOCH
        restore_path = model_save_dir + '/model-' + str(start_epoch - 1)
        # A resumed run continues the epoch numbering of the restored model.
        epoch_offset = start_epoch
    else:
        epoch_offset = 0

    # ---- graph ----
    model = SSAD_SCDM(options, word_emb_init)
    inputs, outputs = model.build_train()
    t_loss = outputs['loss_all']
    # Per-facial-key tensors accumulated into the accuracy / precision
    # numerators below (presumably match counts; confirm in SSAD_SCDM).
    t_facial_map_accuracy_dict = outputs['facial_map_true']
    t_facial_map_precision_dict = outputs['facial_map_true_true']
    # (name used in the epoch log, key in loss_all_dict, tensor).
    tracked_losses = [
        ('loss', 'mean_loss', t_loss),
        ('loss_ssad', 'mean_ssad', outputs['loss_ssad']),
        ('positive_loss_all', 'mean_positive_loss',
         outputs['positive_loss_all']),
        ('hard_negative_loss_all', 'mean_hard_negative_loss',
         outputs['hard_negative_loss_all']),
        ('easy_negative_loss_all', 'mean_easy_negative_loss',
         outputs['easy_negative_loss_all']),
        ('smooth_center_loss_all', 'mean_smooth_center_loss',
         outputs['smooth_center_loss_all']),
        ('smooth_width_loss_all', 'mean_smooth_width_loss',
         outputs['smooth_width_loss_all']),
        ('facial_positive_loss_all', 'mean_facial_positive_loss',
         outputs['facial_positive_loss_all']),
        ('facial_negative_loss_all', 'mean_facial_negative_loss',
         outputs['facial_negative_loss_all']),
    ]

    t_feature_segment = inputs['feature_segment']
    t_sentence_index_placeholder = inputs['sentence_index_placeholder']
    t_sentence_w_len = inputs['sentence_w_len']
    t_gt_overlap = inputs['gt_overlap']
    t_gt_facial_map = inputs['gt_facial_map']

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.per_process_gpu_memory_fraction = 0.9  # cap at 90% of GPU memory
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    sess = tf.InteractiveSession(config=config)

    optimizer = optimizer_factory[options['optimizer']](
        **options['opt_arg'][options['optimizer']])
    if options['clip']:
        gvs = optimizer.compute_gradients(t_loss)
        # compute_gradients may return a None gradient for variables the loss
        # does not depend on; tf.clip_by_value cannot take None, so skip them.
        capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var)
                      for grad, var in gvs if grad is not None]
        train_op = optimizer.apply_gradients(capped_gvs)
    else:
        train_op = optimizer.minimize(t_loss)

    with tf.device("/cpu:0"):
        saver = tf.train.Saver(max_to_keep=400)
    tf.initialize_all_variables().run()

    if finetune:
        with tf.device("/cpu:0"):
            saver.restore(sess, restore_path)

    # ---- training ----
    loss_all_dict = {dict_key: [] for _, dict_key, _ in tracked_losses}
    facial_dict = options['facial_dict']
    n_facial_classes = len(facial_dict)
    tStart_total = time.time()
    now_epoch = None  # checkpoint step of the most recently finished epoch

    for epoch in range(options['max_epochs']):
        now_epoch = epoch + epoch_offset
        # Running totals per facial key for this epoch.
        facial_map_accuracy_dict_all = dict.fromkeys(facial_dict, 0.0)
        facial_map_precision_dict_all = dict.fromkeys(facial_dict, 0.0)
        facial_map_gt_dict = dict.fromkeys(facial_dict, 0.0)

        index = np.arange(len(train_data))
        np.random.shuffle(index)
        train_data = train_data[index]

        tStart_epoch = time.time()
        n_files = len(train_data)
        # One slot per h5 file for each tracked loss.
        epoch_losses = {name: np.zeros(n_files) for name, _, _ in tracked_losses}

        for file_idx in range(n_files):
            print('\r ', file_idx, '/', n_files, end='    ')

            # Read everything this batch needs, then close the h5 file. The
            # generate_batch_* helpers receive the open video_name dataset.
            with h5py.File(train_data[file_idx], 'r') as current_batch:
                raw_captions = current_batch['sentence']
                current_captions = np.array([
                    bytes.decode(raw_captions[ind])
                    for ind in range(options['batch_size'])
                ])
                current_video_name = current_batch['video_name']
                current_facial_map = generate_batch_facial_map(
                    current_video_name)
                current_anchor_input = np.array(current_batch['anchor_input'])
                current_video_feats, _ = generate_batch_video_fts(
                    current_video_name)

            # Remove punctuation and surrounding whitespace; an empty caption
            # becomes '.'.
            for ind in range(options['batch_size']):
                caption = current_captions[ind].translate(no_punct).strip()
                current_captions[ind] = caption if caption else '.'

            # Words missing from the vocabulary are dropped.
            current_caption_ind = [
                [wordtoix[word] for word in cap.lower().split(' ')
                 if word in wordtoix]
                for cap in current_captions
            ]
            current_caption_matrix = sequence.pad_sequences(
                current_caption_ind,
                padding='post',
                maxlen=options['max_sen_len'] - 1)
            # Append one zero column: rows become max_sen_len wide and always
            # end with a 0 index.
            current_caption_matrix = np.hstack([
                current_caption_matrix,
                np.zeros([len(current_caption_matrix), 1])
            ]).astype(int)
            # Sentence length = number of non-zero word indices per row.
            current_caption_length = (current_caption_matrix != 0).sum(axis=1)

            results = sess.run(
                [train_op, t_facial_map_accuracy_dict,
                 t_facial_map_precision_dict] +
                [tensor for _, _, tensor in tracked_losses],
                feed_dict={
                    t_feature_segment: current_video_feats,
                    t_sentence_index_placeholder: current_caption_matrix,
                    t_sentence_w_len: current_caption_length,
                    t_gt_overlap: current_anchor_input,
                    t_gt_facial_map: current_facial_map
                })
            facial_map_accuracy_dict = results[1]
            facial_map_precision_dict = results[2]
            for (name, _, _), value in zip(tracked_losses, results[3:]):
                epoch_losses[name][file_idx] = value

            for key in facial_dict:
                facial_map_accuracy_dict_all[key] += facial_map_accuracy_dict[key]
                facial_map_precision_dict_all[key] += facial_map_precision_dict[key]
                class_offset = int(facial_dict[key])
                for i, map_len in enumerate(options['feature_map_len']):
                    # Last axis: every n_facial_classes-th entry starting at
                    # this key's index, up to (#scale ratios of layer i + 1)
                    # * n_facial_classes.
                    last_axis_stop = (
                        len(options['scale_ratios_anchor%d' % (i + 1)]) *
                        n_facial_classes)
                    facial_map_gt_dict[key] += np.sum(
                        current_facial_map[:, i:i + 1, :map_len,
                                           class_offset:last_axis_stop:
                                           n_facial_classes])

        logging.info("Epoch: {:d} done.".format(now_epoch))
        tStop_epoch = time.time()
        logging.info('Epoch Time Cost: {:f} s'.format(
            round(tStop_epoch - tStart_epoch, 2)))

        for name, dict_key, _ in tracked_losses:
            epoch_mean = np.mean(epoch_losses[name])
            logging.info('Current Epoch Mean {:s} {:f}'.format(name, epoch_mean))
            loss_all_dict[dict_key].append(epoch_mean)

        accuracy_denominator = (n_files * options['batch_size'] *
                                sum(options['feature_map_len']) *
                                len(options['scale_ratios_anchor1']))
        for key in facial_dict:
            logging.info(
                'Current Epoch classifier-{:s} accuracy-{:.4f}  precision-{:.4f}  GT-{:f}'
                .format(
                    key,
                    facial_map_accuracy_dict_all[key] / accuracy_denominator,
                    # max(0.1, ...) avoids dividing by zero when no GT hits.
                    facial_map_precision_dict_all[key] /
                    max(0.1, facial_map_gt_dict[key]),
                    facial_map_gt_dict[key]))

        # ---- checkpoint (every epoch after epoch 4) ----
        if epoch > 4:
            logging.info('Epoch {:d} is done. Saving the model ...'.format(
                now_epoch))
            with tf.device("/cpu:0"):
                saver.save(sess,
                           os.path.join(model_save_dir, 'model'),
                           global_step=now_epoch)

    if now_epoch is None:
        # max_epochs was 0: there is no trained epoch to checkpoint.
        logging.info("No epochs were run; skipping the final save.")
    else:
        logging.info("Finally, saving the model ...")
        with tf.device("/cpu:0"):
            saver.save(sess,
                       os.path.join(model_save_dir, 'model'),
                       global_step=now_epoch)

    np.save(os.path.join(options['result_save_dir'] + task, 'epoch_loss.npy'),
            loss_all_dict)
    tStop_total = time.time()
    logging.info("Total Time Cost: {:f} s".format(
        round(tStop_total - tStart_total, 2)))