Пример #1
0
def test(args):
    """Evaluate a trained NTM one-shot learning model on the Omniglot test split.

    Restores the latest checkpoint for ``args.model``/``args.label_type``,
    runs ``args.test_batch_num`` test batches, then prints per-position
    accuracy (tab-separated, matching the printed header) and the mean loss.
    """
    model = NTMOneShotLearningModel(args)
    data_loader = OmniglotDataLoader(args)
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(args.save_dir + '/' + args.model +
                                         '_' + args.label_type)
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(
            "Test Result\n1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tloss"
        )
        y_list = []
        output_list = []
        loss_list = []
        for episode in range(args.test_batch_num):
            x_image, x_label, y = data_loader.fetch_batch(
                args,
                mode='test',
                augment=args.augment,
                sample_strategy=args.sample_strategy)
            feed_dict = {
                model.x_image: x_image,
                model.x_label: x_label,
                model.y: y
            }
            output, learning_loss = sess.run([model.o, model.loss],
                                             feed_dict=feed_dict)
            y_list.append(y)
            output_list.append(output)
            loss_list.append(learning_loss)
        # Aggregate over all batches before computing accuracy so the
        # statistics cover the entire test run, not a single batch.
        accuracy = compute_accuracy(args, np.concatenate(y_list, axis=0),
                                    np.concatenate(output_list, axis=0))
        # Fix: print accuracies tab-separated on one line so the output lines
        # up with the tab-separated header above (was one value per line).
        for accu in accuracy:
            print('%.4f' % accu, end='\t')
        print(np.mean(loss_list))
Пример #2
0
def test(args):
    """Restore the saved NTM model and report test accuracies and mean loss."""
    model = NTMOneShotLearningModel(args)
    data_loader = OmniglotDataLoader(
        image_size=(args.image_width, args.image_height),
        n_train_classses=args.n_train_classes,
        n_test_classes=args.n_test_classes
    )
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(args.save_dir + '/' + args.model)
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Test Result\n1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tloss")
        targets, predictions, losses = [], [], []
        for _ in range(args.test_batch_num):
            x_image, x_label, y = data_loader.fetch_batch(
                args.n_classes, args.batch_size, args.seq_length,
                type='test',
                augment=args.augment,
                label_type=args.label_type)
            batch_feed = {model.x_image: x_image, model.x_label: x_label, model.y: y}
            output, batch_loss = sess.run([model.o, model.learning_loss],
                                          feed_dict=batch_feed)
            targets.append(y)
            predictions.append(output)
            losses.append(batch_loss)
        # Score the whole test run at once rather than batch by batch.
        accuracy = test_f(args,
                          np.concatenate(targets, axis=0),
                          np.concatenate(predictions, axis=0))
        for accu in accuracy:
            print('%.4f' % accu, end='\t')
        print(np.mean(losses))
Пример #3
0
def train(args):
    """Train the NTM one-shot model on Omniglot.

    Runs ``args.num_epoches`` training batches; every 100 batches it
    evaluates one test batch (logging loss/accuracy/ponder steps and a
    TensorBoard summary), and every 5000 batches it saves a checkpoint
    unless ``args.no_save_model`` is set.
    """
    model = NTMOneShotLearningModel(args)
    data_loader = OmniglotDataLoader(
        image_size=(args.image_width, args.image_height),
        n_train_classses=args.n_train_classes,
        n_test_classes=args.n_test_classes
    )
    # Single source of truth for the checkpoint directory; used for both
    # restoring and saving below.
    save_path = args.save_dir + '/' + args.model
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    with tf.Session(config=config) as sess:
        if args.debug:
            sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        if args.restore_training:
            saver = tf.train.Saver()
            # Fix: ``ckpt`` was referenced here without ever being defined,
            # so restoring a checkpoint raised NameError.
            ckpt = tf.train.get_checkpoint_state(save_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            saver = tf.train.Saver(tf.global_variables())
            tf.global_variables_initializer().run()
        train_writer = tf.summary.FileWriter(args.tensorboard_dir + '/' + args.model, sess.graph)
        print(args)
        print("1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tbatch\tloss\tponder")
        for b in range(args.num_epoches):

            # Test

            if b % 100 == 0:
                x_image, x_label, y = data_loader.fetch_batch(args.n_classes, args.batch_size, args.seq_length,
                                                              type='test',
                                                              augment=args.augment,
                                                              label_type=args.label_type)
                feed_dict = {model.x_image: x_image, model.x_label: x_label, model.y: y}
                output, learning_loss, ponder_steps = sess.run([model.o, model.learning_loss, model.mean_ponder_steps], feed_dict=feed_dict)
                merged_summary = sess.run(model.learning_loss_summary, feed_dict=feed_dict)
                train_writer.add_summary(merged_summary, b)
                # state_list = sess.run(model.state_list, feed_dict=feed_dict)  # For debugging
                # with open('state_long.txt', 'w') as f:
                #     print(state_list, file=f)
                accuracy = test_f(args, y, output)
                for accu in accuracy:
                    print('%.4f' % accu, end='\t')
                print('%d\t%.3f\t%.2f' % (b, learning_loss, ponder_steps))

            # Save model

            if b % 5000 == 0 and not args.no_save_model:
                # Fix: ``save_path`` was previously undefined here (NameError
                # on the very first save attempt at b == 0).
                saver.save(sess, save_path + '/model.tfmodel', global_step=b)

            # Train

            x_image, x_label, y = data_loader.fetch_batch(args.n_classes, args.batch_size, args.seq_length,
                                                          type='train',
                                                          sample_strategy=args.sample_strategy,
                                                          augment=args.augment,
                                                          label_type=args.label_type)
            feed_dict = {model.x_image: x_image, model.x_label: x_label, model.y: y}
            sess.run(model.train_op, feed_dict=feed_dict)
Пример #4
0
def train(args):
    """Train an NTM one-shot model on Omniglot or a Kinetics-derived dataset.

    Selects the data loader from ``args.dataset_type``, then alternates
    training batches with periodic validation (every ``args.validation_freq``
    batches, averaged over ``args.batches_validation`` batches) and periodic
    checkpointing (every ``args.model_save_freq`` batches).  Optionally
    serializes intermediate loss/accuracy plots via ``serialize_plot``.
    """
    eprint("Args: ", args)
    exp_name = gen_exp_name(args)
    eprint(exp_name)
    eprint("Loading in Model")
    model = NTMOneShotLearningModel(args)
    eprint("Loading Data")
    # For Omniglot, train and validation share one loader; for the Kinetics
    # variants, separate 'train' and 'val' loaders are constructed.
    if args.dataset_type == 'omniglot':
        data_loader = OmniglotDataLoader(image_size=(args.image_width,
                                                     args.image_height),
                                         n_train_classes=args.n_train_classes,
                                         n_test_classes=args.n_test_classes,
                                         data_dir=args.data_dir)
        test_data_loader = data_loader
    elif args.dataset_type == 'kinetics_dynamic':
        data_loader = InputLoader('dynamic_image',
                                  'train',
                                  im_size=args.image_height,
                                  args=args)
        test_data_loader = InputLoader('dynamic_image',
                                       'val',
                                       im_size=args.image_height,
                                       args=args)
    elif args.dataset_type == 'kinetics_video':
        data_loader = InputLoader('raw_video',
                                  'train',
                                  args=args,
                                  im_size=args.image_height)
        test_data_loader = InputLoader('raw_video',
                                       'val',
                                       args=args,
                                       im_size=args.image_height)
    elif args.dataset_type == 'kinetics_single_frame':
        data_loader = InputLoader('single_frame',
                                  'train',
                                  args=args,
                                  im_size=args.image_height)
        test_data_loader = InputLoader('single_frame',
                                       'val',
                                       args=args,
                                       im_size=args.image_height)
    # NOTE(review): an unrecognized dataset_type leaves both loaders
    # undefined and fails later with NameError — consider validating early.

    eprint("Starting Session")
    with tf.Session() as sess:
        eprint("Started Session")
        if args.tf_debug_flag:
            sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        if args.restore_training:
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(args.save_dir + '/' +
                                                 args.model)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Saver is only constructed when args.model_saver is set; the
            # save branch below is guarded by the same flag.
            if args.model_saver:
                eprint("Starting saver")
                saver = tf.train.Saver(tf.global_variables())
                eprint("Finished saver")
            tf.global_variables_initializer().run()
            eprint("Finished Initialization")
        if args.summary_writer:
            train_writer = tf.summary.FileWriter(
                args.tensorboard_dir + '/' + args.model, sess.graph)
            eprint("Train Writer Finished")
        eprint(args)
        eprint(
            "1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tbatch\tloss")

        # Random suffix to keep serialized plot files from separate runs
        # from overwriting each other.
        rand_val = str(np.random.randint(0, 100))

        loss_list = []
        accuracy_list = []
        for b in range(args.num_epoches):

            # Test

            if b % args.validation_freq == 0:
                learning_loss_list, output_list, y_list = [], [], []
                for val_index in range(args.batches_validation):
                    # Omniglot's fetch_batch takes an explicit type='test';
                    # the Kinetics loaders are already bound to the 'val'
                    # split at construction time.
                    if args.dataset_type == 'omniglot':
                        x_image, x_label, y = data_loader.fetch_batch(
                            args.n_classes,
                            args.batch_size,
                            args.seq_length,
                            type='test',
                            sampling_strategy=args.sampling_strategy,
                            augment=args.augment,
                            label_type=args.label_type)
                    else:
                        # all kinetics loaders
                        if args.debug:
                            eprint("Loading in validation data")
                        x_image, x_label, y = test_data_loader.fetch_batch(
                            args.n_classes,
                            args.batch_size,
                            args.seq_length,
                            sampling_strategy=args.sampling_strategy,
                            augment=args.augment,
                            label_type=args.label_type)
                    feed_dict = {
                        model.x_image: x_image,
                        model.x_label: x_label,
                        model.y: y,
                        model.is_training: False
                    }
                    if args.debug:
                        eprint("Validation running session")

                    output, learning_loss = sess.run(
                        [model.o, model.learning_loss], feed_dict=feed_dict)
                    if args.summary_writer:
                        merged_summary = sess.run(model.learning_loss_summary,
                                                  feed_dict=feed_dict)
                        train_writer.add_summary(merged_summary, b)
                    learning_loss_list.append(learning_loss)
                    output_list.append(output)
                    y_list.append(y)
                # state_list = sess.run(model.state_list, feed_dict=feed_dict)  # For debugging
                # with open('state_long.txt', 'w') as f:
                #     print(state_list, file=f)
                # Score the full validation run in one pass.
                output_total = np.concatenate(output_list, axis=0)
                y_total = np.concatenate(y_list, axis=0)
                learning_loss = np.mean(learning_loss_list)
                #  set_trace()
                accuracy, total = test_f(args, y_total, output_total)
                eprint(end='\t')
                for accu in accuracy:
                    eprint2('%.4f' % accu, end='\t')
                eprint2(end='\n')
                eprint(end='\t')
                for tot in total:
                    eprint2('%d' % tot, end='\t')

                eprint2('%d\t%.4f' % (b, learning_loss))

                if args.serialize:
                    accuracy_list.append(accuracy)
                    dir_name = args.save_dir + '/' + exp_name
                    mkdir(dir_name)
                    eprint("Serializing intermediate accuracy: " + rand_val,
                           " exp_name", dir_name)
                    serialize_plot(accuracy_list, dir_name,
                                   "inter_accuracy" + rand_val)

            # Save model

            # b > 0 avoids writing a checkpoint of the untrained model.
            if args.model_saver and b % args.model_save_freq == 0 and b > 0:
                if args.debug:
                    eprint(
                        "saving to: {}".format(args.save_dir + '/' +
                                               args.model + '/model.tfmodel'))

                mkdir(args.save_dir + '/' + args.model)
                saver.save(sess,
                           args.save_dir + '/' + args.model + '/model.tfmodel',
                           global_step=b)

            # Train
            if args.debug:
                eprint("[{}] Fetch Batch".format(b))
            if args.dataset_type == 'omniglot':
                x_image, x_label, y = data_loader.fetch_batch(
                    args.n_classes,
                    args.batch_size,
                    args.seq_length,
                    type='train',
                    sampling_strategy=args.sampling_strategy,
                    augment=args.augment,
                    label_type=args.label_type)
            else:
                x_image, x_label, y = data_loader.fetch_batch(
                    args.n_classes,
                    args.batch_size,
                    args.seq_length,
                    sampling_strategy=args.sampling_strategy,
                    augment=args.augment,
                    label_type=args.label_type)

            if args.debug:
                eprint("[{}] Run Sess".format(b))
            #  feed_dict = {model.x_image: x_image, model.x_label: x_label, model.y: y}
            # NOTE(review): is_training is False here even though this is the
            # training step — looks like a bug (would disable dropout/BN
            # training behavior, if any); confirm against the model definition.
            feed_dict = {
                model.x_image: x_image,
                model.x_label: x_label,
                model.y: y,
                model.is_training: False
            }
            learning_loss, _ = sess.run([model.learning_loss, model.train_op],
                                        feed_dict=feed_dict)
            if args.debug:
                eprint("[{}] Learning Loss: {:.3f}".format(b, learning_loss))
            if args.serialize:
                loss_list.append(learning_loss)
                dir_name = args.save_dir + '/' + exp_name
                mkdir(dir_name)
                eprint("Serializing intermediate loss")
                serialize_plot(loss_list, dir_name, "inter_loss" + rand_val)

    if args.serialize:
        dir_name = args.save_dir + '/' + exp_name
        mkdir(dir_name)
        # rand_val = str(np.random.randint(0, 100))
        eprint("Serializing: ", rand_val)
        serialize_plot(loss_list, dir_name, "loss" + rand_val)
        serialize_plot(accuracy_list, dir_name, "accuracy" + rand_val)
        serialize_plot(args, dir_name, "arguments" + rand_val)
        eprint("Saved all plots")
Пример #5
0
def train(args):
    """Train the NTM one-shot model with periodic evaluation and checkpointing.

    Supports resuming: when ``args.restore_training`` is set, the latest
    checkpoint and previously accumulated accuracy/loss histories are
    reloaded, otherwise variables are freshly initialized.  Train/test
    summaries go to separate TensorBoard writers.
    """
    model = NTMOneShotLearningModel(args)
    data_loader = OmniglotDataLoader(args)
    os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"

    # Make sure the checkpoint directory hierarchy exists before saving.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    if not os.path.exists(args.save_dir + '/' + args.model + '_' +
                          args.label_type):
        os.makedirs(args.save_dir + '/' + args.model + '_' + args.label_type)

    with tf.Session() as sess:
        if args.restore_training:
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(args.save_dir + '/' +
                                                 args.model + '_' +
                                                 args.label_type)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoint paths end in "-<global_step>"; recover the episode
            # counter so result histories line up on resume.
            last_episode = int(str(ckpt.model_checkpoint_path).split('-')[-1])
            all_acc_train, all_loss_train = load_results(args,
                                                         last_episode,
                                                         mode='train')
            all_acc_test, all_loss_test = load_results(args,
                                                       last_episode,
                                                       mode='test')
        else:
            saver = tf.train.Saver(tf.global_variables())
            tf.global_variables_initializer().run()
            # Fix: use floor division — ``seq_length / n_classes`` is a float
            # in Python 3, and a float in an array shape raises TypeError.
            all_acc_train = all_acc_test = np.zeros(
                (0, args.seq_length // args.n_classes))
            all_loss_train = all_loss_test = np.array([])

        train_writer = tf.summary.FileWriter(
            args.tensorboard_dir + args.model + '_' + args.label_type +
            '/train/', sess.graph)
        test_writer = tf.summary.FileWriter(args.tensorboard_dir + args.model +
                                            '_' + args.label_type + '/test/')
        print(
            '---------------------------------------------------------------------------------------------'
        )
        print(args)
        print(
            '---------------------------------------------------------------------------------------------'
        )
        print(
            "1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tepisode\tloss")
        for episode in range(args.num_episodes):

            # Train
            x_image, x_label, y = data_loader.fetch_batch(
                args,
                mode='train',
                augment=args.augment,
                sample_strategy=args.sample_strategy)
            feed_dict = {
                model.x_image: x_image,
                model.x_label: x_label,
                model.y: y
            }
            sess.run(model.train_op, feed_dict=feed_dict)
            if episode % args.disp_freq == 0 and episode > 0:
                output, train_loss = sess.run([model.o, model.loss],
                                              feed_dict=feed_dict)
                summary_train = sess.run(model.loss_summary,
                                         feed_dict=feed_dict)
                train_writer.add_summary(summary_train, episode)
                train_acc = compute_accuracy(args, y, output)
                all_acc_train, all_loss_train = display_and_save(
                    args,
                    all_acc_train,
                    train_acc,
                    all_loss_train,
                    train_loss,
                    episode,
                    mode='train')

            # Test
            if episode % args.test_freq == 0 and episode > 0:
                x_image, x_label, y = data_loader.fetch_batch(
                    args,
                    mode='test',
                    augment=args.augment,
                    sample_strategy=args.sample_strategy)
                feed_dict = {
                    model.x_image: x_image,
                    model.x_label: x_label,
                    model.y: y
                }
                output, test_loss = sess.run([model.o, model.loss],
                                             feed_dict=feed_dict)
                summary_test = sess.run(model.loss_summary,
                                        feed_dict=feed_dict)
                test_writer.add_summary(summary_test, episode)
                test_acc = compute_accuracy(args, y, output)
                all_acc_test, all_loss_test = display_and_save(args,
                                                               all_acc_test,
                                                               test_acc,
                                                               all_loss_test,
                                                               test_loss,
                                                               episode,
                                                               mode='test')

            # Save model
            if episode % args.save_freq == 0 and episode > 0:
                saver.save(sess,
                           args.save_dir + '/' + args.model + '_' +
                           args.label_type + '/model.tfmodel',
                           global_step=episode)
Пример #6
0
def train(args):
    """Train the NTM one-shot model indefinitely, validating every 100 batches
    and checkpointing every 5000 batches.

    The per-batch training loss is fetched in the same session run as the
    training op, so each iteration does a single forward/backward pass.
    """
    model = NTMOneShotLearningModel(args)

    data_loader = OmniglotDataLoader(image_size=(args.image_width,
                                                 args.image_height),
                                     n_train_classses=args.n_train_classes,
                                     n_test_classes=args.n_test_classes)

    with tf.Session() as sess:
        if args.debug:
            sess = tf_debug.LocalCLIDebugWrapperSession(sess)

        if args.restore_training:
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(args.save_dir + '/' +
                                                 args.model)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            saver = tf.train.Saver(tf.global_variables())
            tf.global_variables_initializer().run()

        train_writer = tf.summary.FileWriter(
            args.tensorboard_dir + '/' + args.model, sess.graph)
        #print(args)
        #print("1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tbatch\tloss")
        for b in range(10000000):

            print("Inside the Iteration")

            # Test

            if b % 100 == 0:  # this is more of validating the model
                # Loss on the test batch is observed only; no optimization.
                x_image, x_label, y = data_loader.fetch_batch(
                    args.n_classes,
                    args.batch_size,
                    args.seq_length,
                    type='test',
                    augment=args.augment,
                    label_type=args.label_type)
                feed_dict = {
                    model.x_image: x_image,
                    model.x_label: x_label,
                    model.y: y
                }
                output, learning_loss = sess.run(
                    [model.o, model.learning_loss], feed_dict=feed_dict)

                merged_summary = sess.run(model.learning_loss_summary,
                                          feed_dict=feed_dict)
                train_writer.add_summary(merged_summary, b)

                # state_list = sess.run(model.state_list, feed_dict=feed_dict)  # For debugging
                # with open('state_long.txt', 'w') as f:
                #     print(state_list, file=f)
                accuracy = test_f(args, y, output)
                for accu in accuracy:
                    print(accu)
                print((b, learning_loss))

            # Save model

            if b % 5000 == 0 and b > 0:
                saver.save(sess,
                           args.save_dir + '/' + args.model + '/model.tfmodel',
                           global_step=b)

            # Train

            x_image, x_label, y = data_loader.fetch_batch(
                args.n_classes,
                args.batch_size,
                args.seq_length,
                type='train',
                augment=args.augment,
                label_type=args.label_type)
            feed_dict = {
                model.x_image: x_image,
                model.x_label: x_label,
                model.y: y
            }
            # Fix: the loss and train op were previously run in two separate
            # sess.run calls, doing a redundant extra forward pass per batch.
            learning_loss, _ = sess.run([model.learning_loss, model.train_op],
                                        feed_dict=feed_dict)
            print(learning_loss, "Learning loss")
Пример #7
0
def train(args):
    """Train the NTM one-shot model on Omniglot, evaluating every 100 batches
    and saving a checkpoint every 5000 batches."""
    model = NTMOneShotLearningModel(args)
    data_loader = OmniglotDataLoader(
        data_dir=
        "/Users/xavier.qiu/Documents/ricecourse/comp590Research/data/omniglot/images_background/",
        image_size=(args.image_width, args.image_height),
        n_train_classses=args.n_train_classes,
        n_test_classes=args.n_test_classes)
    with tf.Session() as sess:
        if args.debug:
            sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        if args.restore_training:
            # Resume from the most recent checkpoint for this model.
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(args.save_dir + '/' +
                                                 args.model)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Fresh run: initialize all variables.
            saver = tf.train.Saver(tf.global_variables())
            tf.global_variables_initializer().run()
        train_writer = tf.summary.FileWriter(
            args.tensorboard_dir + '/' + args.model, sess.graph)
        print(args)
        print("1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tbatch\tloss")
        for step in range(args.num_epoches):

            # Test

            if step % 100 == 0:
                images, labels, targets = data_loader.fetch_batch(
                    args.n_classes,
                    args.batch_size,
                    args.seq_length,
                    type='test',
                    augment=args.augment,
                    label_type=args.label_type)
                eval_feed = {
                    model.x_image: images,
                    model.x_label: labels,
                    model.y: targets
                }
                predictions, eval_loss = sess.run(
                    [model.o, model.learning_loss], feed_dict=eval_feed)
                merged_summary = sess.run(model.learning_loss_summary,
                                          feed_dict=eval_feed)
                train_writer.add_summary(merged_summary, step)
                # state_list = sess.run(model.state_list, feed_dict=eval_feed)  # For debugging
                # with open('state_long.txt', 'w') as f:
                #     print(state_list, file=f)
                for accu in test_f(args, targets, predictions):
                    print('%.4f' % accu, end='\t')
                print('%d\t%.4f' % (step, eval_loss))

            # Save model

            if step > 0 and step % 5000 == 0:
                saver.save(sess,
                           args.save_dir + '/' + args.model + '/model.tfmodel',
                           global_step=step)

            # Train

            images, labels, targets = data_loader.fetch_batch(
                args.n_classes,
                args.batch_size,
                args.seq_length,
                type='train',
                augment=args.augment,
                label_type=args.label_type)
            sess.run(model.train_op,
                     feed_dict={
                         model.x_image: images,
                         model.x_label: labels,
                         model.y: targets
                     })