예제 #1
0
def train(epoch, dataset, config, log_dir):
    """Train model for one epoch."""
    model_config = config['model']
    train_config = config['train']
    sess_config = config['session']

    with tf.Graph().as_default():
        model = GRA(model_config)
        model.build_inference()
        model.build_loss(train_config['reg_coeff'], train_config['shu_coeff'])
        model.build_train(train_config['learning_rate'])

        with tf.Session(config=sess_config) as sess:
            sum_dir = os.path.join(log_dir, 'summary')
            # create event file for graph
            if not os.path.exists(sum_dir):
                summary_writer = tf.summary.FileWriter(sum_dir, sess.graph)
                summary_writer.close()
            summary_writer = tf.summary.FileWriter(sum_dir)

            ckpt_dir = os.path.join(log_dir, 'checkpoint')
            ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
            saver = tf.train.Saver()
            if ckpt_path:
                print('load checkpoint {}.'.format(ckpt_path))
                lajidaima = int(ckpt_path.split('-')[-1]) - epoch + 1
                saver.restore(sess, ckpt_path)
            else:
                print('no checkpoint.')
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                sess.run(tf.global_variables_initializer())
            epoch += lajidaima
            stats_dir = os.path.join(log_dir, 'stats')
            stats_path = os.path.join(stats_dir, 'train.json')
            if os.path.exists(stats_path):
                print('load stats file {}.'.format(stats_path))
                stats = pd.read_json(stats_path, 'records')
            else:
                print('no stats file.')
                if not os.path.exists(stats_dir):
                    os.makedirs(stats_dir)
                stats = pd.DataFrame(columns=['epoch', 'loss', 'acc'])

            # train iterate over batch
            batch_idx = 0
            total_loss = 0
            total_acc = 0
            batch_total = np.sum(dataset.train_batch_total)

            while dataset.has_train_batch:
                vgg, c3d, question, answer = dataset.get_train_batch()
                vgg = np.zeros((len(vgg), len(vgg[0]), len(vgg[0][0])))
                feed_dict = {
                    model.appear: vgg,
                    model.motion: c3d,
                    model.question_encode: question,
                    model.answer_encode: answer
                }
                _, loss, prediction = sess.run(
                    [model.train, model.loss, model.prediction], feed_dict)

                # cal acc
                correct = 0
                for i, row in enumerate(prediction[1]):
                    #print(row)
                    for index in row:
                        if answer[i][index] == 1:
                            correct += 1
                            break
                acc = correct / len(answer)

                total_loss += loss
                total_acc += acc
                if batch_idx % 10 == 0:
                    print(
                        '[TRAIN] epoch {}, batch {}/{}, loss {:.5f}, acc {:.5f}.'
                        .format(epoch, batch_idx, batch_total, loss, acc))
                batch_idx += 1

            loss = total_loss / batch_total
            acc = total_acc / batch_total
            print('\n[TRAIN] epoch {}, loss {:.5f}, acc {:.5f}.\n'.format(
                epoch, loss, acc))

            summary = tf.Summary()
            summary.value.add(tag='train/loss', simple_value=float(loss))
            summary.value.add(tag='train/acc', simple_value=float(acc))
            summary_writer.add_summary(summary, epoch)

            record = Series([epoch, loss, acc], ['epoch', 'loss', 'acc'])
            stats = stats.append(record, ignore_index=True)

            saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'), epoch)
            stats.to_json(stats_path, 'records')
            dataset.reset_train()
            return loss, acc
예제 #2
0
def val(epoch, dataset, config, log_dir):
    """Validate model."""
    model_config = config['model']
    sess_config = config['session']
    train_config = config['train']

    answerset = pd.read_csv(
        os.path.join(config['preprocess_dir'], 'answer_set.txt'), header=None)[0]

    example_id = 0

    with tf.Graph().as_default():
        model = GRA(model_config)
        model.build_inference()
        model.build_loss(train_config['reg_coeff'], train_config['shu_coeff'])

        result = DataFrame(columns=['id', 'answer'])
        with tf.Session(config=sess_config) as sess:
            sum_dir = os.path.join(log_dir, 'summary')
            summary_writer = tf.summary.FileWriter(sum_dir)

            ckpt_dir = os.path.join(log_dir, 'checkpoint')
            save_path = tf.train.latest_checkpoint(ckpt_dir)
            saver = tf.train.Saver()
            last_epoch = 0
            if save_path:
                print('load checkpoint {}.'.format(save_path))
                last_epoch = int(save_path.split('-')[-1]) - epoch
                saver.restore(sess, save_path)
            else:
                print('no checkpoint.')
                exit()

            stats_dir = os.path.join(log_dir, 'stats')
            stats_path = os.path.join(stats_dir, 'val.json')
            if os.path.exists(stats_path):
                print('load stats file {}.'.format(stats_path))
                stats = pd.read_json(stats_path, 'records')
            else:
                print('no stats file.')
                if not os.path.exists(stats_dir):
                    os.makedirs(stats_dir)
                stats = pd.DataFrame(columns=['epoch', 'acc'])

            # val iterate over examples
            correct = 0
            loss_total = 0
            while dataset.has_val_example:
                vgg, c3d, question, answer = dataset.get_val_example()
                feed_dict = {
                    model.appear: [vgg],
                    model.question_encode: [question],
                    model.answer_encode: [answer]
                }
                loss, prediction = sess.run([model.loss, model.prediction], feed_dict=feed_dict)
                loss_total += loss
                prediction = prediction[1]
                for i, row in enumerate(prediction):
                    for index in row:
                        if answer[index] == 1:
                            correct += 1
                            break
                result = result.append({'id': example_id, 'answer': prediction}, ignore_index=True)
                example_id += 1
            acc = correct / dataset.val_example_total
            loss = loss_total / dataset.val_example_total
            result.to_json(os.path.join(
                log_dir, 'validation_' + str(int(acc * 100)) + '_'  + str(epoch + last_epoch)  +  '.json'), 'records')
            print('\n[VAL] epoch {}, acc {:.5f}.\n'.format(epoch + last_epoch, acc))

            summary = tf.Summary()
            summary.value.add(tag='val/acc', simple_value=float(acc))
            summary.value.add(tag='val/loss', simple_value=float(loss))
            summary_writer.add_summary(summary, epoch + last_epoch)

            record = Series([epoch + last_epoch, acc, loss], ['epoch', 'acc', 'loss'])
            stats = stats.append(record, ignore_index=True)
            stats.to_json(stats_path, 'records')

            dataset.reset_val()
            return acc