Code Example #1
File: main.py Project: isoberkeley/sum-var
def main():
    if FLAGS.train:
        test_num_updates = FLAGS.num_updates
    else:
        test_num_updates = 5
    data_generator = DataGenerator()
    data_generator.generate_time_series_batch(train=FLAGS.train)
    model = MAML(data_generator.batch_size, test_num_updates)
    model.construct_model(input_tensors=None, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(
        tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES),
        max_to_keep=10)
    sess = tf.InteractiveSession()

    exp_string = FLAGS.train_csv_file + '.numstep' + str(test_num_updates) + '.updatelr' + str(FLAGS.meta_lr)


    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(
            FLAGS.logdir + '/' + exp_string)
        print(model_file)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, sess, exp_string, data_generator)
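All of these examples read a module-level FLAGS object. A minimal sketch of the flag definitions this first script relies on, assuming TensorFlow 1.x's flags module; the flag names are taken from how main() reads them, but the defaults and help strings are assumptions, not taken from the project:

from tensorflow.python.platform import flags

FLAGS = flags.FLAGS
# hypothetical defaults, inferred only from how main() reads each flag
flags.DEFINE_bool('train', True, 'train a model, otherwise restore and test')
flags.DEFINE_bool('resume', True, 'resume from the latest checkpoint in logdir')
flags.DEFINE_integer('num_updates', 5, 'number of inner-gradient updates')
flags.DEFINE_integer('test_iter', -1, 'checkpoint iteration to restore for testing')
flags.DEFINE_float('meta_lr', 0.001, 'meta (outer-loop) learning rate')
flags.DEFINE_string('logdir', '/tmp/data', 'directory for summaries and checkpoints')
flags.DEFINE_string('train_csv_file', 'train', 'CSV file holding the training time series')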
Code Example #2
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5  # During base-testing (and thus meta-updating) 5 updates are used
        else:
            test_num_updates = 10  # During meta-testing 10 updates are used
    elif FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs':
        if FLAGS.train:
            test_num_updates = 1  # eval on at least one update during training
        else:
            test_num_updates = 10  # eval on 10 updates during testing
    else:
        test_num_updates = 10  # Omniglot gets 10 updates during training AND testing

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        # DataGenerator(num_samples_per_class, batch_size, config={})
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:  # Dealing with a non 'sinusoid' dataset here
        if FLAGS.metatrain_iterations == 0 and (
                FLAGS.datasource == 'miniimagenet'
                or FLAGS.datasource == 'cifarfs'):
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs':  # TODO - use 15 val examples for imagenet?
                if FLAGS.train:  # TODO: why +15 and *2? Following Ravi: "15 examples per class were used for evaluating the post-update meta-gradient" = MAML algo 2, line 10 --> see how 5 and 15 are split up in maml.py
                    # DataGenerator(number_of_images_per_class, number_of_tasks_in_batch)
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15, FLAGS.meta_batch_size
                    )  # K support examples plus 15 query examples per class
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:  # this is for omniglot
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output  # number of classes, e.g. 5 for miniImagenet tasks
    if FLAGS.baseline == 'oracle':  # NOTE - this flag is specific to sinusoid
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input  # np.prod(self.img_size) for images

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'cifarfs':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            # meta train : num_total_batches = 200000 (number of tasks, not number of meta-iterations)
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1
                               ])  # slice(tensor, begin, slice_size)
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])  # the remaining examples (the extra 15 per class) form the query set
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        # meta val: num_total_batches = 600 (number of tasks, not number of meta-iterations)
        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1
                           ])  # slice the training examples here
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(
        dim_input, dim_output, test_num_updates=test_num_updates
    )  # test_num_updates: eval on at least one update during training, 10 during testing
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')

    # op that merges all collected summaries
    model.summ_op = tf.summary.merge_all()

    # keep last 10 copies of trainable variables
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    # remove the need to explicitly pass this Session object to run ops
    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    # cls = no of classes
    # mbs = meta batch size
    # ubs = update batch size
    # numstep = number of INNER GRADIENT updates
    # updatelr = inner gradient step
    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    # initialize all variables
    tf.global_variables_initializer().run()
    # start threads for all queue runners collected in the graph
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
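The four tf.slice calls above split each task's examples along axis 1 into a support set (inputa/labela: the first num_classes * update_batch_size examples) and a query set (inputb/labelb: everything after them). A minimal NumPy sketch of the same split, with made-up shapes for illustration:

import numpy as np

meta_batch_size, num_classes, k_shot, num_query = 4, 5, 1, 15
dim_input = 84 * 84 * 3  # flattened miniImagenet image, as in the examples above
images = np.zeros((meta_batch_size, num_classes * (k_shot + num_query), dim_input),
                  dtype=np.float32)

# same boundary as the tf.slice calls: the first num_classes * k_shot examples
# are the support set ("a"), the remainder is the query set ("b")
inputa = images[:, :num_classes * k_shot, :]  # shape (4, 5, 21168)
inputb = images[:, num_classes * k_shot:, :]  # shape (4, 75, 21168)
print(inputa.shape, inputb.shape)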
Code Example #3
def main():
    os.makedirs(FLAGS.logdir, exist_ok=True)  # exist_ok=True already handles an existing directory

    test_num_updates = 10

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                   FLAGS.meta_batch_size)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    tf_data_load = True
    num_classes = data_generator.num_classes

    if FLAGS.train:  # only construct training model if needed
        random.seed(5)
        image_tensor, label_tensor = data_generator.make_data_tensor()
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }

    random.seed(6)
    image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
    inputa = tf.slice(image_tensor, [0, 0, 0],
                      [-1, num_classes * FLAGS.update_batch_size, -1])
    inputb = tf.slice(image_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    labela = tf.slice(label_tensor, [0, 0, 0],
                      [-1, num_classes * FLAGS.update_batch_size, -1])
    labelb = tf.slice(label_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    metaval_input_tensors = {
        'inputa': inputa,
        'inputb': inputb,
        'labela': labela,
        'labelb': labelb
    }

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'model_{}'.format(FLAGS.model_num)

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    # unlike the other examples, this variant always trains and then immediately tests
    FLAGS.train = True
    train(model, saver, sess, exp_string, data_generator, resume_itr)
    FLAGS.train = False
    test(model, saver, sess, exp_string, data_generator, test_num_updates)
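Every example resumes by parsing the iteration number out of the checkpoint filename: index('model') finds where the word 'model' starts, and the characters after those five letters are the iteration. A small sketch of that parsing in isolation, with a made-up path:

model_file = '/tmp/logs/cls_5.mbs_4/model60000'  # hypothetical checkpoint path

ind1 = model_file.index('model')         # first occurrence of 'model' in the whole path
resume_itr = int(model_file[ind1 + 5:])  # '60000' -> 60000
print(resume_itr)

# caveat: str.index returns the FIRST occurrence, so a directory whose name also
# contains 'model' (e.g. the exp_string 'model_{}' built in this very example)
# would make the slice start inside the directory name and int() would raise ValueError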
Code Example #4
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    elif FLAGS.datasource == 'miniimagenet':
        if FLAGS.train:
            test_num_updates = 1  # eval on at least one update during training
        else:
            test_num_updates = 10
    else:
        test_num_updates = 10

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size)  # K support examples plus 15 query examples per class
                else:
                    data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory


    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train: # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
            labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
            input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
        inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
        labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
        metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_'+str(FLAGS.num_classes)+'.mbs_'+str(FLAGS.meta_batch_size) + '.ubs_' + str(FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1+5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
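The exp_string built above encodes the run's hyperparameters into the checkpoint directory name. As a worked example with assumed values (num_classes=5, meta_batch_size=4, train_update_batch_size=1, num_updates=5, train_update_lr=0.01, the default num_filters=64, max_pool/stop_grad/baseline off, and norm='batch_norm'):

exp_string = ('cls_' + str(5) + '.mbs_' + str(4) + '.ubs_' + str(1)
              + '.numstep' + str(5) + '.updatelr' + str(0.01))
exp_string += 'batchnorm'
print(exp_string)  # cls_5.mbs_4.ubs_1.numstep5.updatelr0.01batchnorm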
Code Example #5
File: main.py Project: siavash-khodadadeh/maml
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    elif FLAGS.datasource == 'miniimagenet':
        if FLAGS.train:
            test_num_updates = 1  # eval on at least one update during training
        else:
            test_num_updates = TEST_NUM_UPDATES
    else:
        test_num_updates = 10

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet':  # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    # data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])

            import tensorflow_hub as hub
            augmentation_module = hub.Module(
                'https://tfhub.dev/google/image_augmentation/nas_cifar/1',
                name='am1')

            augmentation_module2 = hub.Module(
                'https://tfhub.dev/google/image_augmentation/flipx_crop_rotate_color/1',
                name='am2')

            meta_batch_size = inputa.get_shape()[0]  # static meta-batch dimension
            dim = inputa.get_shape()[1]

            # note: inputb is rebuilt from inputa here; the augmented copies of the
            # support images become the second set and are paired with the same labels below
            inputb = tf.reshape(inputa, (meta_batch_size, dim, 84, 84, 3))
            result = list()
            for i in range(meta_batch_size):
                images = augmentation_module(
                    {
                        'images': inputb[i, ...],
                        'image_size': (84, 84),
                        'augmentation': True,
                    },
                    signature='from_decoded_images')

                images = augmentation_module2(
                    {
                        'images': images,
                        'image_size': (84, 84),
                        'augmentation': True,
                    },
                    signature='from_decoded_images')

                # projective transform [a0, a1, a2, b0, b1, b2, c0, c1]: identity
                # rotation/scale plus a random translation of up to 20 px in x and y
                # (tf.contrib.image.transform expects float32 parameters, so the
                # original int32 random values are replaced with float32 ones)
                transforms = [
                    1., 0., -tf.random.uniform(
                        shape=(), minval=-20., maxval=20.), 0., 1.,
                    -tf.random.uniform(
                        shape=(), minval=-20., maxval=20.), 0., 0.
                ]
                images = tf.contrib.image.transform(images, transforms)
                result.append(images)

            inputb = tf.stack(result)

            inputb = tf.reshape(inputb, (meta_batch_size, dim, 84 * 84 * 3))
            labelb = labela

            if FLAGS.train:  # always true here (this block is inside the FLAGS.train branch above), so the else below is never taken
                input_tensors = {
                    'inputa': inputb,
                    'inputb': inputa,
                    'labela': labela,
                    'labelb': labelb
                }
            else:
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb
                }

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
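The augmentation block in this example chains two TF-Hub image-augmentation modules and then applies a random translation through an 8-parameter projective transform. A minimal NumPy-free sketch of what that parameter vector means, assuming tf.contrib.image.transform's convention that an output pixel (x, y) samples the input at ((a0*x + a1*y + a2)/k, (b0*x + b1*y + b2)/k) with k = c0*x + c1*y + 1:

def sample_location(x, y, t):
    """Map an output pixel (x, y) to the input location it samples from."""
    a0, a1, a2, b0, b1, b2, c0, c1 = t
    k = c0 * x + c1 * y + 1.0
    return (a0 * x + a1 * y + a2) / k, (b0 * x + b1 * y + b2) / k

tx, ty = 7.0, -3.0  # a made-up random shift from the [-20, 20] range used above
transforms = [1., 0., -tx, 0., 1., -ty, 0., 0.]  # identity plus a translation

print(sample_location(10.0, 10.0, transforms))  # (3.0, 13.0): content shifts by (tx, ty)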
Code Example #6
File: test.py Project: augustdemi/demi
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu
    TOTAL_NUM_AU = 8
    all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1
        temp_kshot = FLAGS.update_batch_size
        FLAGS.update_batch_size = 1
    if FLAGS.model.startswith('m2'):
        temp_num_updates = FLAGS.num_updates
        FLAGS.num_updates = 1



    data_generator = DataGenerator()

    dim_output = data_generator.num_classes
    dim_input = data_generator.dim_input

    inputa, inputb, labela, labelb = data_generator.make_data_tensor()
    metatrain_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    model = MAML(dim_input, dim_output)
    model.construct_model(input_tensors=metatrain_input_tensors, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=20)

    sess = tf.InteractiveSession()


    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.update_batch_size = temp_kshot
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.model.startswith('m2'):
        FLAGS.num_updates = temp_num_updates

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    print('initial weights: ', sess.run('model/b1:0'))
    print("========================================================================================")

    ################## Test ##################
    def _load_weight_m(trained_model_dir):
        all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
        if FLAGS.au_idx < TOTAL_NUM_AU: all_au = [all_au[FLAGS.au_idx]]
        w_arr = None
        b_arr = None
        for au in all_au:
            model_file = None
            print('model file dir: ', FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            print("model_file from ", au, ": ", model_file)
            if model_file is None:
                print("############################################################################################")
                print("####################################################################### None for ", au)
                print("############################################################################################")
            else:
                if FLAGS.test_iter > 0:
                    files = os.listdir(model_file[:model_file.index('model')])
                    if 'model' + str(FLAGS.test_iter) + '.index' in files:
                        model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                        print("model_file by test_iter > 0: ", model_file)
                    else:
                        print(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", files)
                print("Restoring model weights from " + model_file)

                saver.restore(sess, model_file)
                w = sess.run('model/w1:0')
                b = sess.run('model/b1:0')
                print("updated weights from ckpt: ", b)
                print('----------------------------------------------------------')
                if w_arr is None:
                    w_arr = w
                    b_arr = b
                else:
                    w_arr = np.hstack((w_arr, w))
                    b_arr = np.vstack((b_arr, b))

        return w_arr, b_arr

    def _load_weight_s(sbjt_start_idx):
        batch_size = 10
        # if a single model was trained using all AUs, only that one model needs to be loaded
        if FLAGS.model.startswith('s1'):
            three_layers = feature_layer(batch_size, TOTAL_NUM_AU)
            three_layers.loadWeight(FLAGS.vae_model_to_test, FLAGS.au_idx, num_au_for_rm=TOTAL_NUM_AU)
        # if each AU has its own model, the per-AU weights must be stacked
        else:
            three_layers = feature_layer(batch_size, 1)
            all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
            if FLAGS.au_idx < TOTAL_NUM_AU: all_au = [all_au[FLAGS.au_idx]]
            w_arr = None
            b_arr = None
            for au in all_au:
                if FLAGS.model.startswith('s3'):
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter100'
                elif FLAGS.model.startswith('s4'):
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_subject' + str(
                        sbjt_start_idx + 1) + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter10_maml_adad' + str(FLAGS.test_iter)
                else:
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter200_kshot10_iter10_nobatch_adam_noinit'
                three_layers.loadWeight(load_model_path, au)
                print('=============== Model S loaded from ', load_model_path)
                w = three_layers.model_intensity.layers[-1].get_weights()[0]
                b = three_layers.model_intensity.layers[-1].get_weights()[1]
                print('----------------------------------------------------------')
                if w_arr is None:
                    w_arr = w
                    b_arr = b
                else:
                    w_arr = np.hstack((w_arr, w))
                    b_arr = np.vstack((b_arr, b))

        return w_arr, b_arr



    def _load_weight_m0(trained_model_dir):
        model_file = None
        print('--------- model file dir: ', FLAGS.logdir + trained_model_dir)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + trained_model_dir)
        print(">>>> model_file from all_aus: ", model_file)
        if model_file is None:
            print("####################################################################### None for all_aus")
        else:
            if FLAGS.test_iter > 0:
                files = os.listdir(model_file[:model_file.index('model')])
                if 'model' + str(FLAGS.test_iter) + '.index' in files:
                    model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                    print(">>>> model_file2: ", model_file)
                else:
                    print(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", files)
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
            w = sess.run('model/w1:0')
            b = sess.run('model/b1:0')
            print("updated weights from ckpt: ", b)
            print('----------------------------------------------------------')
        return w, b  # note: if model_file was None above, w and b are undefined here and this raises UnboundLocalError

    print("<<<<<<<<<<<< CONCATENATE >>>>>>>>>>>>>>")
    save_path = "./logs/result/"
    y_hat = []
    y_lab = []
    if FLAGS.all_sub_model:  # the model was trained on all subjects
        print('---------------- all sub model ----------------')
        # the weights only need to be loaded once, since the model does not differ per subject
        if FLAGS.model.startswith('m'):
            trained_model_dir = '/cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
                FLAGS.meta_batch_size) + '.ubs_' + str(
                FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
            if FLAGS.model.startswith('m0'):
                w_arr, b_arr = _load_weight_m0(trained_model_dir)
            else:
                w_arr, b_arr = _load_weight_m(trained_model_dir)  # a separate model per AU

        ### test per each subject and concatenate
        for i in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(i)

            result = test_each_subject(w_arr, b_arr, i)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
            print(">> y_hat_all shape:", np.vstack(y_hat).shape)
            print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab), log_dir=save_path + "/" + "test.txt")
    else:  # the model was trained per subject: only possible for the VAE and MAML train_test cases, plus the local-weight test case
        for subj_idx in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(subj_idx)
            else:
                trained_model_dir = '/sbjt' + str(subj_idx) + '.ubs_' + str(
                    FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
                w_arr, b_arr = _load_weight_m(trained_model_dir)
            result = test_each_subject(w_arr, b_arr, subj_idx)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
            print(">> y_hat_all shape:", np.vstack(y_hat).shape)
            print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab),
                      log_dir=save_path + "/test.txt")

    end_time = datetime.now()
    elapse = end_time - start_time
    print("=======================================================")
    print(">>>>>> elapse time: " + str(elapse))
    print("=======================================================")
Code Example #7
def run(d):
    #IPython.embed()
    config = d['config']

    ################################################################
    ###################### Load parameters #########################
    ################################################################
    previous_dynamics_model = config["previous_dynamics_model"]
    train_now = d['train_bool']

    ################################################################
    desired_shape_for_rollout = config["testing"]["desired_shape_for_rollout"]
    save_rollout_run_num = config["testing"]["save_rollout_run_num"]
    rollout_save_filename = desired_shape_for_rollout + str(
        save_rollout_run_num)

    num_steps_per_rollout = config["testing"]["num_steps_per_rollout"]
    if (desired_shape_for_rollout == "figure8"):
        num_steps_per_rollout = 400
    elif (desired_shape_for_rollout == "zigzag"):
        num_steps_per_rollout = 150
    ##############################################################

    #settings
    cheaty_training = False
    use_one_hot = False  #True
    use_camera = False  #True
    playback_mode = False

    state_representation = "exclude_x_y"  #["exclude_x_y", "all"]

    # Settings (generally, keep these to default)
    default_addrs = [b'\x00\x01']
    use_pid_mode = True
    slow_pid_mode = True
    visualize_rviz = True  #turning this off can make things go faster
    visualize_True = True
    visualize_False = False
    noise_True = True
    noise_False = False
    make_aggregated_dataset_noisy = True
    make_training_dataset_noisy = True
    perform_forwardsim_for_vis = True
    print_minimal = False
    noiseToSignal = 0
    if (make_training_dataset_noisy):
        noiseToSignal = 0.01

    # Defining datatypes
    tf_datatype = tf.float32
    np_datatype = np.float32

    # Setting motor limits
    left_min = 1200
    right_min = 1200
    left_max = 2000
    right_max = 2000
    if (use_pid_mode):
        if (slow_pid_mode):
            left_min = 2 * math.pow(2, 16) * 0.001
            right_min = 2 * math.pow(2, 16) * 0.001
            left_max = 9 * math.pow(2, 16) * 0.001
            right_max = 9 * math.pow(2, 16) * 0.001
        else:  # this hasn't been tested yet
            left_min = 4 * math.pow(2, 16) * 0.001
            right_min = 4 * math.pow(2, 16) * 0.001
            left_max = 12 * math.pow(2, 16) * 0.001
            right_max = 12 * math.pow(2, 16) * 0.001

    #vars from config

    curr_agg_iter = config['aggregation']['curr_agg_iter']
    save_dir = d['exp_name']
    print("\n\nSAVING EVERYTHING TO: ", save_dir)

    #make directories
    if not os.path.exists(save_dir + '/saved_rollouts'):
        os.makedirs(save_dir + '/saved_rollouts')
    if not os.path.exists(save_dir + '/saved_rollouts/' +
                          rollout_save_filename + '_aggIter' +
                          str(curr_agg_iter)):
        os.makedirs(save_dir + '/saved_rollouts/' + rollout_save_filename +
                    '_aggIter' + str(curr_agg_iter))

    ######################################
    ######## GET TRAINING DATA ###########
    ######################################

    print("\n\nCURR AGGREGATION ITER: ", curr_agg_iter)

    # Training data
    # Random
    dataX = []
    dataX_full = []  # this is just for your personal use for forwardsim (for debugging)
    dataY = []
    dataZ = []

    # Training data
    # MPC
    dataX_onPol = []
    dataX_full_onPol = []
    dataY_onPol = []
    dataZ_onPol = []

    # Validation data
    # Random
    dataX_val = []
    dataX_full_val = []
    dataY_val = []
    dataZ_val = []

    # Validation data
    # MPC
    dataX_val_onPol = []
    dataX_full_val_onPol = []
    dataY_val_onPol = []
    dataZ_val_onPol = []

    training_ratio = config['training']['training_ratio']
    for agg_itr_counter in range(curr_agg_iter + 1):

        #getDataFromDisk should give (tasks, rollouts from that task, each rollout has its points)
        dataX_curr, dataY_curr, dataZ_curr, dataX_curr_full = getDataFromDisk(
            config['experiment_type'],
            use_one_hot,
            use_camera,
            cheaty_training,
            state_representation,
            agg_itr_counter,
            config_training=config['training'])

        if (agg_itr_counter == 1):
            print("*********TRYING TO FIND THE WEIRD ROLLOUT...")
            for rollout in range(len(dataX_curr[2])):
                val = dataX_curr[2][rollout][:, 4]
                if (np.any(val < 0)):
                    dataX_curr[2][rollout] = dataX_curr[2][rollout + 1]
                    dataY_curr[2][rollout] = dataY_curr[2][rollout + 1]
                    dataZ_curr[2][rollout] = dataZ_curr[2][rollout + 1]
                    print("FOUND IT!!!!!!! rollout number ", rollout)

        #random data
        #go from dataX_curr (tasks, rollouts, steps) --> to dataX (tasks, some rollouts, steps) and dataX_val (tasks, some rollouts, steps)
        if (agg_itr_counter == 0):
            for task_num in range(len(dataX_curr)):
                taski_num_rollout = len(dataX_curr[task_num])
                print("task" + str(task_num) + "_num_rollouts: ",
                      taski_num_rollout)

                #for each task, append something like (356, 48, 22) (numrollouts per task, num steps in that rollout, dim)
                dataX.append(dataX_curr[task_num][:int(taski_num_rollout *
                                                       training_ratio)])
                dataX_full.append(
                    dataX_curr_full[task_num][:int(taski_num_rollout *
                                                   training_ratio)])
                dataY.append(dataY_curr[task_num][:int(taski_num_rollout *
                                                       training_ratio)])
                dataZ.append(dataZ_curr[task_num][:int(taski_num_rollout *
                                                       training_ratio)])

                dataX_val.append(dataX_curr[task_num][int(taski_num_rollout *
                                                          training_ratio):])
                dataX_full_val.append(
                    dataX_curr_full[task_num][int(taski_num_rollout *
                                                  training_ratio):])
                dataY_val.append(dataY_curr[task_num][int(taski_num_rollout *
                                                          training_ratio):])
                dataZ_val.append(dataZ_curr[task_num][int(taski_num_rollout *
                                                          training_ratio):])

        #on-policy data
        #go from dataX_curr (tasks, rollouts, steps) --> to dataX_onPol (tasks, some rollouts, steps) and dataX_val_onPol (tasks, some rollouts, steps)
        elif (agg_itr_counter == 1):

            for task_num in range(len(dataX_curr)):
                taski_num_rollout = len(dataX_curr[task_num])
                print("task" + str(task_num) + "_num_rollouts for onpolicy: ",
                      taski_num_rollout)

                dataX_onPol.append(
                    dataX_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataX_full_onPol.append(
                    dataX_curr_full[task_num][:int(taski_num_rollout *
                                                   training_ratio)])
                dataY_onPol.append(
                    dataY_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataZ_onPol.append(
                    dataZ_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])

                dataX_val_onPol.append(
                    dataX_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataX_full_val_onPol.append(
                    dataX_curr_full[task_num][int(taski_num_rollout *
                                                  training_ratio):])
                dataY_val_onPol.append(
                    dataY_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataZ_val_onPol.append(
                    dataZ_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])

        #on-policy data
        #go from dataX_curr (tasks, rollouts, steps) --> to ADDING ONTO dataX_onPol (tasks, some more rollouts than before, steps) and dataX_val_onPol (tasks, some more rollouts than before, steps)
        else:
            for task_num in range(len(dataX_curr)):

                taski_num_rollout = len(dataX_curr[task_num])
                print("task" + str(task_num) + "_num_rollouts for onpolicy: ",
                      taski_num_rollout)

                dataX_onPol[task_num].extend(
                    dataX_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataX_full_onPol[task_num].extend(
                    dataX_curr_full[task_num][:int(taski_num_rollout *
                                                   training_ratio)])
                dataY_onPol[task_num].extend(
                    dataY_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataZ_onPol[task_num].extend(
                    dataZ_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])

                dataX_val_onPol[task_num].extend(
                    dataX_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataX_full_val_onPol[task_num].extend(
                    dataX_curr_full[task_num][int(taski_num_rollout *
                                                  training_ratio):])
                dataY_val_onPol[task_num].extend(
                    dataY_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataZ_val_onPol[task_num].extend(
                    dataZ_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])

    #############################################################

    #count number of random and onpol data points
    total_random_data = len(dataX) * len(dataX[1]) * len(
        dataX[1][0])  # numSteps = tasks * rollouts * steps
    if (len(dataX_onPol) == 0):
        total_onPol_data = 0
    else:
        total_onPol_data = len(dataX_onPol) * len(dataX_onPol[0]) * len(
            dataX_onPol[0][0]
        )  #this is approximate because each task doesn't have the same num rollouts or the same num steps
    total_num_data = total_random_data + total_onPol_data
    print()
    print()
    print("Number of random data points: ", total_random_data)
    print("Number of on-policy data points: ", total_onPol_data)
    print("TOTAL number of data points: ", total_num_data)

    #############################################################

    #combine random and onpol data into a single dataset for training
    ratio_new = config["aggregation"]["ratio_new"]
    num_new_pts = ratio_new * (total_random_data) / (1.0 - ratio_new)
    if (len(dataX_onPol) == 0):
        num_times_to_copy_onPol = 0
    else:
        num_times_to_copy_onPol = int(num_new_pts / total_onPol_data)
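    # worked example of the ratio math above (made-up numbers): with
    # total_random_data = 9000 and ratio_new = 0.1, num_new_pts =
    # 0.1 * 9000 / 0.9 = 1000; if total_onPol_data = 500, the on-policy rollouts
    # are copied int(1000 / 500) = 2 times, leaving on-policy data at roughly
    # 1000 / (9000 + 1000) = 10% of the combined training set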

    #copy all rollouts from each task of onpol data, and do this copying this many times
    for i in range(num_times_to_copy_onPol):
        for task_num in range(len(dataX_onPol)):
            for rollout_num in range(len(dataX_onPol[task_num])):
                dataX[task_num].append(dataX_onPol[task_num][rollout_num])
                dataX_full[task_num].append(
                    dataX_full_onPol[task_num][rollout_num])
                dataY[task_num].append(dataY_onPol[task_num][rollout_num])
                dataZ[task_num].append(dataZ_onPol[task_num][rollout_num])
    #print("num_times_to_copy_onPol: ", num_times_to_copy_onPol)

    # make a list of all X,Y,Z so can take mean of them
    # concatenate state and action --> inputs (for training)
    all_points_inp = []
    all_points_outp = []
    outputs = copy.deepcopy(dataZ)
    inputs = copy.deepcopy(dataX)
    for task_num in range(len(dataX)):
        for rollout_num in range(len(dataX[task_num])):

            #this will just be a big list of everything, so can take the mean
            input_pts = np.concatenate(
                (dataX[task_num][rollout_num], dataY[task_num][rollout_num]),
                axis=1)
            output_pts = dataZ[task_num][rollout_num]

            # the concatenated (state, action) inputs, kept for training later
            inputs[task_num][rollout_num] = np.concatenate(
                [dataX[task_num][rollout_num], dataY[task_num][rollout_num]],
                axis=1)

            all_points_inp.append(input_pts)
            all_points_outp.append(output_pts)
    all_points_inp = np.concatenate(all_points_inp)
    all_points_outp = np.concatenate(all_points_outp)

    ## concatenate state and action --> inputs (for validation)
    outputs_val = copy.deepcopy(dataZ_val)
    inputs_val = copy.deepcopy(dataX_val)
    for task_num in range(len(dataX_val)):
        for rollout_num in range(len(dataX_val[task_num])):
            #dataX[task_num][rollout_num] (steps x s_dim)
            #dataY[task_num][rollout_num] (steps x a_dim)
            inputs_val[task_num][rollout_num] = np.concatenate([
                dataX_val[task_num][rollout_num],
                dataY_val[task_num][rollout_num]
            ],
                                                               axis=1)

    ## concatenate state and action --> inputs (for validation onpol)
    outputs_val_onPol = copy.deepcopy(dataZ_val_onPol)
    inputs_val_onPol = copy.deepcopy(dataX_val_onPol)
    for task_num in range(len(dataX_val_onPol)):
        for rollout_num in range(len(dataX_val_onPol[task_num])):
            #dataX[task_num][rollout_num] (steps x s_dim)
            #dataY[task_num][rollout_num] (steps x a_dim)
            inputs_val_onPol[task_num][rollout_num] = np.concatenate([
                dataX_val_onPol[task_num][rollout_num],
                dataY_val_onPol[task_num][rollout_num]
            ],
                                                                     axis=1)

    #############################################################

    #inputs should now be (tasks, rollouts from that task, [s,a])
    #outputs should now be (tasks, rollouts from that task, [ds])
    #IPython.embed()

    inputSize = inputs[0][0].shape[1]
    outputSize = outputs[1][0].shape[1]
    print("\n\nDimensions:")
    print("states: ", dataX[1][0].shape[1])
    print("actions: ", dataY[1][0].shape[1])
    print("inputs to NN: ", inputSize)
    print("outputs of NN: ", outputSize)

    mean_inp = np.expand_dims(np.mean(all_points_inp, axis=0), axis=0)
    std_inp = np.expand_dims(np.std(all_points_inp, axis=0), axis=0)
    mean_outp = np.expand_dims(np.mean(all_points_outp, axis=0), axis=0)
    std_outp = np.expand_dims(np.std(all_points_outp, axis=0), axis=0)
    print("\n\nCalulated means and stds... ", mean_inp.shape, std_inp.shape,
          mean_outp.shape, std_outp.shape, "\n\n")

    ###########################################################
    ## CREATE regressor, policy, data generator, maml model
    ###########################################################

    # create regressor (NN dynamics model)
    regressor = DeterministicMLPRegressor(
        inputSize, outputSize, outputSize, tf_datatype, config['seed'],
        config['training']['weight_initializer'], config['model'])

    # create policy (MPC controller)
    policy = Policy(regressor,
                    inputSize,
                    outputSize,
                    left_min,
                    right_min,
                    left_max,
                    right_max,
                    state_representation=state_representation,
                    visualize_rviz=visualize_rviz,
                    x_index=config['roach']['x_index'],
                    y_index=config['roach']['y_index'],
                    yaw_cos_index=config['roach']['yaw_cos_index'],
                    yaw_sin_index=config['roach']['yaw_sin_index'],
                    **config['policy'])

    # create MAML model
    # note: this also constructs the actual regressor network/weights
    model = MAML(regressor, inputSize, outputSize, config)
    model.construct_model(input_tensors=None, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    # GPU config proto
    gpu_device = 0
    gpu_frac = 0.4  #0.4 #0.8 #0.3
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_device)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_frac)
    config_2 = tf.ConfigProto(gpu_options=gpu_options,
                              log_device_placement=False,
                              allow_soft_placement=True,
                              inter_op_parallelism_threads=1,
                              intra_op_parallelism_threads=1)
    # saving
    saver = tf.train.Saver(max_to_keep=10)
    sess = tf.InteractiveSession(config=config_2)

    # initialize tensorflow vars
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    # set the mean/std of regressor according to mean/std of the data we have so far
    regressor.update_params_data_dist(mean_inp, std_inp, mean_outp, std_outp,
                                      total_num_data)

    ###########################################################
    ## TRAIN THE DYNAMICS MODEL
    ###########################################################

    #train on the given full dataset, for max_epochs
    if train_now:
        if config["training"]["restore_previous_dynamics_model"]:
            print("\n\nRESTORING PREVIOUS DYNAMICS MODEL FROM ",
                  previous_dynamics_model, " AND CONTINUING TRAINING...\n\n")
            saver.restore(sess, previous_dynamics_model)

        np.save(save_dir + "/inputs.npy", inputs)
        np.save(save_dir + "/outputs.npy", outputs)
        np.save(save_dir + "/inputs_val.npy", inputs_val)
        np.save(save_dir + "/outputs_val.npy", outputs_val)

        train(inputs, outputs, curr_agg_iter, model, saver, sess, config,
              inputs_val, outputs_val, inputs_val_onPol, outputs_val_onPol)
    else:
        print("\n\nRESTORING A DYNAMICS MODEL FROM ", previous_dynamics_model)
        saver.restore(sess, previous_dynamics_model)

    ###########################################################
    ## RUN THE MPC CONTROLLER
    ###########################################################

    #create controller node
    controller_node = GBAC_Controller(
        sess,
        policy,
        model,
        use_pid_mode=use_pid_mode,
        state_representation=state_representation,
        default_addrs=default_addrs,
        update_batch_size=config['testing']['update_batch_size'],
        num_updates=config['testing']['num_updates'],
        de=config['testing']['dynamic_evaluation'],
        roach_config=config['roach'])

    #do 1 rollout
    print(
        "\n\n\nPAUSING... right before a controller run... RESET THE ROBOT TO A GOOD LOCATION BEFORE CONTINUING..."
    )
    #IPython.embed()
    resulting_x, selected_u, desired_seq, list_robot_info, list_mocap_info, old_saving_format_dict, list_best_action_sequences = controller_node.run(
        num_steps_per_rollout, desired_shape_for_rollout)

    #where to save this rollout
    pathStartName = save_dir + '/saved_rollouts/' + rollout_save_filename + '_aggIter' + str(
        curr_agg_iter)
    print("\n\n************** TRYING TO SAVE EVERYTHING TO: ", pathStartName)

    #save the result of the run
    np.save(pathStartName + '/oldFormat_actions.npy',
            old_saving_format_dict['actions_taken'])
    np.save(pathStartName + '/oldFormat_desired.npy',
            old_saving_format_dict['desired_states'])
    np.save(pathStartName + '/oldFormat_executed.npy',
            old_saving_format_dict['traj_taken'])
    np.save(pathStartName + '/oldFormat_perp.npy',
            old_saving_format_dict['save_perp_dist'])
    np.save(pathStartName + '/oldFormat_forward.npy',
            old_saving_format_dict['save_forward_dist'])
    np.save(pathStartName + '/oldFormat_oldforward.npy',
            old_saving_format_dict['saved_old_forward_dist'])
    np.save(pathStartName + '/oldFormat_movedtonext.npy',
            old_saving_format_dict['save_moved_to_next'])
    np.save(pathStartName + '/oldFormat_desheading.npy',
            old_saving_format_dict['save_desired_heading'])
    np.save(pathStartName + '/oldFormat_currheading.npy',
            old_saving_format_dict['save_curr_heading'])
    np.save(pathStartName + '/list_best_action_sequences.npy',
            list_best_action_sequences)

    yaml.dump(config, open(osp.join(pathStartName, 'saved_config.yaml'), 'w'))

    #save the result of the run
    np.save(pathStartName + '/actions.npy', selected_u)
    np.save(pathStartName + '/states.npy', resulting_x)
    np.save(pathStartName + '/desired.npy', desired_seq)
    # binary mode is required by pickle
    pickle.dump(list_robot_info, open(pathStartName + '/robotInfo.obj', 'wb'))
    pickle.dump(list_mocap_info, open(pathStartName + '/mocapInfo.obj', 'wb'))

    #stop roach
    print("killing robot")
    controller_node.kill_robot()

    return
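
The (mean, std) pairs that example #7 feeds to regressor.update_params_data_dist are plain per-dimension z-score statistics. A minimal standalone sketch of that normalization, assuming the regressor applies (x - mean) / std internally (the DeterministicMLPRegressor internals are not shown in these examples):

import numpy as np

def normalize(x, mean, std, eps=1e-8):
    # per-dimension z-score; eps guards against near-constant dimensions
    return (x - mean) / (std + eps)

# toy stand-in for the (steps, s_dim + a_dim) inputs above
x = np.random.randn(100, 6)
mean = np.mean(x, axis=0, keepdims=True)  # shape (1, 6), like mean_inp
std = np.std(x, axis=0, keepdims=True)    # shape (1, 6), like std_inp
x_norm = normalize(x, mean, std)
print(x_norm.mean(axis=0))  # approximately 0 per dimension
print(x_norm.std(axis=0))   # approximately 1 per dimension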
コード例 #8
ファイル: main.py プロジェクト: SuperHenry2333/maml_new
def main():

    test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    num_images_per_class = FLAGS.update_batch_size * 3

    data_generator = DataGenerator(
        num_images_per_class, FLAGS.meta_batch_size
    )  # 3 * update_batch_size images per class
    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
    if FLAGS.mode == 'train_with_poison':
        print('Loading poison examples from %s' % FLAGS.poison_dir)
        poison_example = np.load(FLAGS.poison_dir)
        # poison_example=np.load(FLAGS.logdir + '/' + exp_string+'/poisonx_%d.npy'%FLAGS.poison_itr)
    else:
        poison_example = None
    model = MAML(dim_input=dim_input,
                 dim_output=dim_output,
                 num_images_per_class=num_images_per_class,
                 num_classes=FLAGS.num_classes,
                 poison_example=poison_example)
    sess = tf.InteractiveSession()
    print('Session created')
    if FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor(
                train=True, poison=(model.poisonx, model.poisony), sess=sess)
            if FLAGS.reptile:
                inputa = image_tensor
                labela = label_tensor
            else:
                inputa = tf.slice(
                    image_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                labela = tf.slice(
                    label_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            image_tensor, label_tensor = data_generator.make_data_tensor(
                train=False)
            if FLAGS.mode == 'train_poison':
                inputa_test = tf.slice(
                    image_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                inputb_test = tf.slice(
                    image_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                labela_test = tf.slice(
                    label_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                labelb_test = tf.slice(
                    label_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb,
                    'inputa_test': inputa_test,
                    'inputb_test': inputb_test,
                    'labela_test': labela_test,
                    'labelb_test': labelb_test
                }
            else:
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb
                }

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix=FLAGS.mode)

    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    print('Model built')
    model.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    if FLAGS.train == False:
        # change back to the original meta batch size when loading the model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr
    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.train_update_lr) + '.poison_lr' + str(
                        FLAGS.poison_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None
    tf.train.start_queue_runners()
    tf.global_variables_initializer().run()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
    test_params = [
        model, saver, sess, exp_string, data_generator, test_num_updates
    ]
    if FLAGS.train:
        train(model,
              saver,
              sess,
              exp_string,
              data_generator,
              resume_itr,
              test_params=test_params)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
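
The tf.slice calls in example #8 split each task's examples into a support set (the first num_classes * update_batch_size rows, consumed by the inner update) and a query set (the remainder, used for the meta-gradient). A NumPy sketch of the same split, with illustrative sizes:

import numpy as np

meta_batch, num_classes, k_shot = 4, 5, 1
dim = 84 * 84 * 3
examples = np.random.randn(meta_batch, num_classes * k_shot * 2, dim)

n_support = num_classes * k_shot
support = examples[:, :n_support, :]  # tf.slice(x, [0, 0, 0], [-1, n_support, -1])
query = examples[:, n_support:, :]    # tf.slice(x, [0, n_support, 0], [-1, -1, -1])
print(support.shape, query.shape)     # (4, 5, 21168) (4, 5, 21168)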
コード例 #9
ファイル: main.py プロジェクト: TimeLovercc/HSML_for_Graph
def main():

    if FLAGS.datasource == 'multidataset_leave_one_out':
        assert FLAGS.leave_one_out_id > -1

    sess = tf.InteractiveSession()
    if FLAGS.datasource in ['sinusoid', 'mixture']:
        if FLAGS.train:
            test_num_updates = 1
        else:
            test_num_updates = 10
    else:
        if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource in ['sinusoid', 'mixture']:
        data_generator = DataGenerator(FLAGS.update_batch_size + FLAGS.update_batch_size_eval, FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
                if FLAGS.train:
                    data_generator = DataGenerator(FLAGS.update_batch_size + 15,
                                                   FLAGS.meta_batch_size)  # 15 extra examples per class for evaluating the post-update meta-gradient
                else:
                    data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                                   FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                               FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset', 'multidataset_leave_one_out']:
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            if FLAGS.datasource in ['miniimagenet', 'omniglot']:
                image_tensor, label_tensor = data_generator.make_data_tensor()
            elif FLAGS.datasource == 'multidataset':
                image_tensor, label_tensor = data_generator.make_data_tensor_multidataset()
            elif FLAGS.datasource == 'multidataset_leave_one_out':
                image_tensor, label_tensor = data_generator.make_data_tensor_multidataset_leave_one_out()
            inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
            input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

        random.seed(6)
        if FLAGS.datasource in ['miniimagenet', 'omniglot']:
            image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
        elif FLAGS.datasource == 'multidataset':
            image_tensor, label_tensor = data_generator.make_data_tensor_multidataset(train=False)
        elif FLAGS.datasource == 'multidataset_leave_one_out':
            image_tensor, label_tensor = data_generator.make_data_tensor_multidataset_leave_one_out(train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
        metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(sess, dim_input, dim_output, test_num_updates=test_num_updates)

    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(FLAGS.meta_batch_size) + '.ubs_' + str(
        FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
        FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr) + '.emb_loss_weight' + str(
        FLAGS.emb_loss_weight) + '.num_groups' + str(FLAGS.num_groups) + '.emb_type' + str(
        FLAGS.emb_type) + '.hidden_dim' + str(FLAGS.hidden_dim)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    print(exp_string)

    if FLAGS.resume or not FLAGS.train:
        if FLAGS.train == True:
            model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        else:
            print(FLAGS.test_epoch)
            model_file = '{0}/{2}/model{1}'.format(FLAGS.logdir, FLAGS.test_epoch, exp_string)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
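
Every variant above resumes by finding the substring 'model' in the checkpoint path and parsing the trailing digits as the iteration. A standalone sketch of that convention (the path is illustrative); note that str.index returns the first match, so this breaks if the log directory itself contains 'model':

def parse_resume_itr(model_file):
    # checkpoints are written as <logdir>/<exp_string>/model<itr>
    ind1 = model_file.index('model')
    return int(model_file[ind1 + 5:])

print(parse_resume_itr('/tmp/logs/cls_5.mbs_25.ubs_5.numstep1.updatelr0.01/model60000'))  # 60000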
コード例 #10
def main():
    temp = FLAGS.update_batch_size
    temp2 = FLAGS.meta_batch_size

    FLAGS.update_batch_size = 1
    FLAGS.meta_batch_size = 1
    data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                   FLAGS.meta_batch_size)

    dim_output = data_generator.num_classes
    dim_input = data_generator.dim_input

    if FLAGS.train:  # only construct training model if needed

        # image_tensor, label_tensor = data_generator.make_data_tensor()
        # inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])  # (tasks, NK, dim) = (meta_batch_size, NK, 2000)
        # # Here NK means N classes stacked K times; each group of N holds one of each label 0..N-1, in random order.
        # inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])  # (tasks, NK, dim) = (meta_batch_size, NK, 2000)
        # labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])  # (tasks, NK, labels) = (meta_batch_size, NK, N)
        # labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])  # (tasks, NK, labels) = (meta_batch_size, NK, N)
        inputa, inputb, labela, labelb = data_generator.make_data_tensor()
        metatrain_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }

    inputa, inputb, labela, labelb = data_generator.make_data_tensor(
        train=False)
    metaval_input_tensors = {
        'inputa': inputa,
        'inputb': inputb,
        'labela': labela,
        'labelb': labelb
    }

    pred_weights = data_generator.pred_weights
    model = MAML(dim_input, dim_output)
    if FLAGS.train:
        model.construct_model(input_tensors=metatrain_input_tensors,
                              prefix='metatrain_')
    else:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=20)

    sess = tf.InteractiveSession()

    FLAGS.update_batch_size = temp
    FLAGS.meta_batch_size = temp2

    trained_model_dir = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(FLAGS.meta_batch_size) + '.ubs_' + str(
        FLAGS.update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
        FLAGS.update_lr) + '.metalr' + str(FLAGS.meta_lr) + '.initweight' + str(FLAGS.init_weight) + \
                        '/sbjt14:13.ubs_' + str(FLAGS.update_batch_size) +'.numstep5.updatelr0.005.metalr0.005'

    # if FLAGS.stop_grad:
    #     trained_model_dir += 'stopgrad'
    # if FLAGS.baseline:
    #     trained_model_dir += FLAGS.baseline
    # else:
    #     print('Norm setting not recognized.')

    resume_itr = 0

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                            trained_model_dir)
    print(">>> kshot: ", FLAGS.update_batch_size)
    print(">>>> train_test model dir: ", model_file)
    saver.restore(sess, model_file)
    w = sess.run('model/w1:0')
    print("global abs of w: ", np.linalg.norm(w))
    b = sess.run('model/b1:0')
    print("global abs of b: ", np.linalg.norm(b))
    model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                            trained_model_dir + '/local')
    for i in range(13):
        model_file = model_file[:model_file.index('subject'
                                                  )] + 'subject' + str(i)
        print(">>>> model_file_local: ", model_file)
        saver.restore(sess, model_file)
        w = sess.run('model/w1:0')
        print("subject ", i, ", abs of w: ", np.linalg.norm(w))
        b = sess.run('model/b1:0')
        print("subject ", i, ", abs of b: ", np.linalg.norm(b))
コード例 #11
def run(d):

	#restore old dynamics model
	train_now = True
	restore_previous = False
	# old_exp_name = 'MAML_roach/terrain_types_turf_model_on_turf'
	# old_model_num = 0
	previous_dynamics_model = "/home/anagabandi/rllab-private/data/local/experiment/MAML_roach_copy/Tuesday_optimization/carpet_on_carpet/model_epoch10"

	num_steps_per_rollout= 140
	desired_shape_for_rollout = "left"                     #straight, left, right, circle_left, zigzag, figure8
	save_rollout_run_num = 0
	rollout_save_filename= desired_shape_for_rollout + str(save_rollout_run_num)

	#settings
	cheaty_training = False
	use_one_hot = False #True
	use_camera = False #True
	playback_mode = False
	
	state_representation = "exclude_x_y" #["exclude_x_y", "all"]

	#don't change much
	default_addrs= [b'\x00\x01']
	use_pid_mode = True      
	slow_pid_mode = True
	visualize_rviz=True   #turning this off can make things go faster
	visualize_True = True
	visualize_False = False
	noise_True = True
	noise_False = False
	make_aggregated_dataset_noisy = True
	make_training_dataset_noisy = True
	perform_forwardsim_for_vis= True
	print_minimal=False
	noiseToSignal = 0
	if(make_training_dataset_noisy):
		noiseToSignal = 0.01

	#datatypes
	tf_datatype= tf.float32 ############################# CHANGE BACK!
	np_datatype= np.float32

	#motor limits
	left_min = 1200
	right_min = 1200
	left_max = 2000
	right_max = 2000
	if(use_pid_mode):
		if(slow_pid_mode):
			left_min = 2*math.pow(2,16)*0.001
			right_min = 2*math.pow(2,16)*0.001
			left_max = 9*math.pow(2,16)*0.001
			right_max = 9*math.pow(2,16)*0.001
		else: # this hasn't been tested yet
			left_min = 4*math.pow(2,16)*0.001
			right_min = 4*math.pow(2,16)*0.001
			left_max = 12*math.pow(2,16)*0.001
			right_max = 12*math.pow(2,16)*0.001

	#vars from config
	config = d['config']
	curr_agg_iter = d['curr_agg_iter']
	save_dir = '/media/anagabandi/f1e71f04-dc4b-4434-ae4c-fcb16447d5b3/' + d['exp_name']
	############################################################################################ CHANGE BACK! 
	#save_dir = '/media/anagabandi/f1e71f04-dc4b-4434-ae4c-fcb16447d5b3/' + d['exp_name']

	print("\n\nSAVING EVERYTHING TO: ", save_dir)

	#make directories
	if not os.path.exists(save_dir + '/saved_rollouts'):
		os.makedirs(save_dir + '/saved_rollouts')
	if not os.path.exists(save_dir + '/saved_rollouts/'+rollout_save_filename+ '_aggIter' +str(curr_agg_iter)):
		os.makedirs(save_dir + '/saved_rollouts/'+rollout_save_filename+ '_aggIter' +str(curr_agg_iter))

	######################################
	######## GET TRAINING DATA ###########
	######################################

	print("\n\nCURR AGGREGATION ITER: ", curr_agg_iter)
	# Training data
	dataX=[]
	dataX_full=[] #this is just for your personal use for forwardsim (for debugging)
	dataY=[]
	dataZ=[]

	# Validation data
	dataX_val = []
	dataX_full_val=[]
	dataY_val=[]
	dataZ_val=[]

	agg_itr = 0
	training_ratio = config['training']['training_ratio']
	for agg_itr in range(curr_agg_iter+1):
		#getDataFromDisk should give (tasks, rollouts from that task, each rollout has its points)
		dataX_curr, dataY_curr, dataZ_curr, dataX_curr_full = getDataFromDisk(agg_itr, config['experiment_type'], 
																			use_one_hot, use_camera, 
																			cheaty_training, state_representation, config['training'])
		if(agg_itr==0):
			for i in range(len(dataX_curr)):
				taski_num_rollout = len(dataX_curr[i])
				print("taski_num_rollout: ", taski_num_rollout)
				dataX.append(dataX_curr[i][:int(taski_num_rollout*training_ratio)])

				dataX_full.append(dataX_curr_full[i][:int(taski_num_rollout*training_ratio)])
				dataY.append(dataY_curr[i][:int(taski_num_rollout*training_ratio)])
				dataZ.append(dataZ_curr[i][:int(taski_num_rollout*training_ratio)])

				dataX_val.append(dataX_curr[i][int(taski_num_rollout*training_ratio):])
				dataX_full_val.append(dataX_curr_full[i][int(taski_num_rollout*training_ratio):])
				dataY_val.append(dataY_curr[i][int(taski_num_rollout*training_ratio):])
				dataZ_val.append(dataZ_curr[i][int(taski_num_rollout*training_ratio):])
			#IPython.embed()
		else:
			#combine these rollouts w previous rollouts, so everything is still organized by task
			for task_num in range(len(dataX)):
				for rollout_num in range(len(dataX_curr[task_num])):
					dataX[task_num].append(dataX_curr[task_num][rollout_num])
					dataY[task_num].append(dataY_curr[task_num][rollout_num])
					dataZ[task_num].append(dataZ_curr[task_num][rollout_num])
					dataX_full[task_num].append(dataX_curr_full[task_num][rollout_num])
			# Do validation for this too! 

	total_num_data = len(dataX)*len(dataX[0])*len(dataX[0][0]) # numSteps = tasks * rollouts * steps
	print("\n\nTotal number of data points: ", total_num_data)

	#return
	## concatenate state and action --> inputs
	outputs = copy.deepcopy(dataZ)
	inputs = copy.deepcopy(dataX)

	#IPython.embed()
	inputs_val = np.append(np.array(dataX_val), np.array(dataY_val), axis = 3)
	outputs_val = np.array(dataZ_val)
	#IPython.embed() # check shapes
	for task_num in range(len(dataX)):
		for rollout_num in range (len(dataX[task_num])):
			#dataX[task_num][rollout_num] (steps x s_dim)
			#dataY[task_num][rollout_num] (steps x a_dim)
			inputs[task_num][rollout_num] = np.concatenate([dataX[task_num][rollout_num], dataY[task_num][rollout_num]], axis=1)
	
	#inputs should now be (tasks, rollouts from that task, [s,a])
	#outputs should now be (tasks, rollouts from that task, [ds])
	inputSize = inputs[0][0].shape[1]
	outputSize = outputs[0][0].shape[1]
	print("\n\nDimensions:")
	print("states: ", dataX[0][0].shape[1])
	print("actions: ", dataY[0][0].shape[1])
	print("inputs to NN: ", inputSize)
	print("outputs of NN: ", outputSize)

	#calc mean/std on full dataset
	if config["model"]["nonlinearity"] == "tanh":
		# Do you scale inputs to [-1, 1] and then standardize outputs?
		#IPython.embed()
		inputs_array = np.array(inputs)
		mean_inp = (inputs_array.max() + inputs_array.min())/2.0
		std_inp = inputs_array.max() - mean_inp

		mean_inp = mean_inp*np.ones((1, inputs_array.shape[3]))
		std_inp = std_inp*np.ones((1, inputs_array.shape[3]))
		#IPython.embed()

		mean_outp = np.expand_dims(np.mean(outputs,axis=(0,1,2)), axis=0)
		std_outp = np.expand_dims(np.std(outputs,axis=(0,1,2)), axis=0)
		#IPython.embed() # How should I expand_dims? # check that after the operation, all inputs do lie in this range
	elif config["model"]["nonlinearity"] == "sigmoid":
		# Do you scale inputs to [0, 1] and then standardize outputs?
		#IPython.embed()
		inputs_array = np.array(inputs)
		mean_inp = inputs_array.min()
		std_inp = inputs_array.max() - mean_inp

		mean_inp = mean_inp*np.ones((1, inputs_array.shape[3]))
		std_inp = std_inp*np.ones((1, inputs_array.shape[3]))

		#IPython.embed()

		mean_outp = np.expand_dims(np.mean(outputs,axis=(0,1,2)), axis=0)
		std_outp = np.expand_dims(np.std(outputs,axis=(0,1,2)), axis=0)
		#IPython.embed() # How should I expand_dims? # check that after the operation, all inputs do lie in this range
	else:  # for all the relu variants
		mean_inp = np.expand_dims(np.mean(inputs,axis=(0,1,2)), axis=0)
		std_inp = np.expand_dims(np.std(inputs,axis=(0,1,2)), axis=0)
		mean_outp = np.expand_dims(np.mean(outputs,axis=(0,1,2)), axis=0)
		std_outp = np.expand_dims(np.std(outputs,axis=(0,1,2)), axis=0)
	print("\n\nCalulated means and stds... ", mean_inp.shape, std_inp.shape, mean_outp.shape, std_outp.shape, "\n\n")

	###########################################################
	## CREATE regressor, policy, data generator, maml model
	###########################################################

	# create regressor (NN dynamics model)
	regressor = DeterministicMLPRegressor(inputSize, outputSize, dim_obs=outputSize, tf_datatype=tf_datatype, seed=config['seed'],weight_initializer=config['training']['weight_initializer'], **config['model'])

	# create policy (MPC controller)
	policy = Policy(regressor, inputSize, outputSize, 
					left_min, right_min, left_max, right_max, state_representation=state_representation,
					visualize_rviz=config['roach']['visualize_rviz'], 
					x_index=config['roach']['x_index'], 
					y_index=config['roach']['y_index'], 
					yaw_cos_index=config['roach']['yaw_cos_index'],
					yaw_sin_index=config['roach']['yaw_sin_index'], 
					**config['policy'])

	# create MAML model
		# note: this also constructs the actual regressor network/weights
	model = MAML(regressor, inputSize, outputSize, config=config['training'])
	model.construct_model(input_tensors=None, prefix='metatrain_')
	model.summ_op = tf.summary.merge_all()

	# GPU config proto
	gpu_device = 0
	gpu_frac = 0.3 #0.3
	os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_device)
	gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_frac)
	config_2 = tf.ConfigProto(gpu_options=gpu_options,
							log_device_placement=False,
							allow_soft_placement=True,
							inter_op_parallelism_threads=1,
							intra_op_parallelism_threads=1)
	# saving
	saver = tf.train.Saver(max_to_keep=10)
	sess = tf.InteractiveSession(config=config_2)

	# initialize tensorflow vars
	tf.global_variables_initializer().run()
	tf.train.start_queue_runners()

	# set the mean/std of regressor according to mean/std of the data we have so far
	regressor.update_params_data_dist(mean_inp, std_inp, mean_outp, std_outp, len(inputs[0])*len(inputs[0][0])*4)

	###########################################################
	## TRAIN THE DYNAMICS MODEL
	###########################################################

	#train on the given full dataset, for max_epochs
	if train_now:
		if(restore_previous):
			print("\n\nRESTORING PREVIOUS DYNAMICS MODEL FROM ", previous_dynamics_model, " AND CONTINUING TRAINING...\n\n")
			saver.restore(sess, previous_dynamics_model)
			
			"""trainable_vars = tf.trainable_variables()
      		weights = sess.run(trainable_vars)
     		with open(osp.join(osp.dirname(previous_dynamics_model), "weights.pickle"), "wb") as output_file:
        		pickle.dump(weights, output_file)"""
		#IPython.embed()
		# np.save(save_dir + "/inputs.npy", inputs)
		# np.save(save_dir + "/outputs.npy", outputs)
		# # mean_inp.shape, std_inp.shape, mean_outp.shape, std_outp.shape
		# np.save(save_dir + "/mean_inp.npy", mean_inp)
		# np.save(save_dir + "/std_inp.npy", std_inp)
		# np.save(save_dir + "/mean_outp.npy", mean_outp)
		# np.save(save_dir + "/std_outp.npy", std_outp)
		
		train(inputs, outputs, curr_agg_iter, model, saver, sess, config, inputs_val, outputs_val)
	else: 
		print("\n\nRESTORING A DYNAMICS MODEL FROM ", previous_dynamics_model)
		saver.restore(sess, previous_dynamics_model)
		#IPython.embed()
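	# NOTE: the early return below skips the forward-sim and MPC rollout code that follows.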
	return
	#IPython.embed()
	predicted_traj = regressor.do_forward_sim(dataX_full[0][0][27:45], dataY[0][0][27:45], state_representation)
	#np.save(save_dir + '/forwardsim_true.npy', dataX_full[0][7][27:45])
	#np.save(save_dir + '/forwardsim_pred.npy', predicted_traj)

	###########################################################
	## RUN THE MPC CONTROLLER
	###########################################################

	#create controller node
	controller_node = GBAC_Controller(sess=sess, policy=policy, model=model,
									state_representation=state_representation, use_pid_mode=use_pid_mode, 
									default_addrs=default_addrs, update_batch_size=config['training']['update_batch_size'], **config['roach'])

	#do 1 rollout
	print("\n\n\nPAUSING... right before a controller run... RESET THE ROBOT TO A GOOD LOCATION BEFORE CONTINUING...")
	#IPython.embed()
	resulting_x, selected_u, desired_seq, list_robot_info, list_mocap_info, old_saving_format_dict = controller_node.run(num_steps_per_rollout, desired_shape_for_rollout)
	
	#where to save this rollout
	pathStartName = save_dir + '/saved_rollouts/'+rollout_save_filename+ '_aggIter' +str(curr_agg_iter)
	print("\n\n************** TRYING TO SAVE EVERYTHING TO: ", pathStartName)

	#save the result of the run
	np.save(pathStartName + '/oldFormat_actions.npy', old_saving_format_dict['actions_taken'])
	np.save(pathStartName + '/oldFormat_desired.npy', old_saving_format_dict['desired_states'])
	np.save(pathStartName + '/oldFormat_executed.npy', old_saving_format_dict['traj_taken'])
	np.save(pathStartName + '/oldFormat_perp.npy', old_saving_format_dict['save_perp_dist'])
	np.save(pathStartName + '/oldFormat_forward.npy', old_saving_format_dict['save_forward_dist'])
	np.save(pathStartName + '/oldFormat_oldforward.npy', old_saving_format_dict['saved_old_forward_dist'])
	np.save(pathStartName + '/oldFormat_movedtonext.npy', old_saving_format_dict['save_moved_to_next'])
	np.save(pathStartName + '/oldFormat_desheading.npy', old_saving_format_dict['save_desired_heading'])
	np.save(pathStartName + '/oldFormat_currheading.npy', old_saving_format_dict['save_curr_heading'])

	yaml.dump(config, open(osp.join(pathStartName, 'saved_config.yaml'), 'w'))

	#save the result of the run
	np.save(pathStartName + '/actions.npy', selected_u)
	np.save(pathStartName + '/states.npy', resulting_x)
	np.save(pathStartName + '/desired.npy', desired_seq)
	# binary mode is required by pickle
	pickle.dump(list_robot_info,open(pathStartName + '/robotInfo.obj','wb'))
	pickle.dump(list_mocap_info,open(pathStartName + '/mocapInfo.obj','wb'))

	#stop roach
	print("killing robot")
	controller_node.kill_robot()

	return
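
Example #11 reuses the regressor's (mean, std) slots to express range scaling rather than z-scoring: for tanh it sets mean = (max + min) / 2 and std = max - mean so that (x - mean) / std lands in [-1, 1]; for sigmoid it sets mean = min and std = max - min to land in [0, 1]. A quick NumPy check of those identities:

import numpy as np

x = np.random.uniform(-3.0, 7.0, size=(1000, 4))

# tanh-style scaling to [-1, 1]
mean_t = (x.max() + x.min()) / 2.0
std_t = x.max() - mean_t
print(((x - mean_t) / std_t).min(), ((x - mean_t) / std_t).max())  # ~-1, ~1

# sigmoid-style scaling to [0, 1]
mean_s = x.min()
std_s = x.max() - mean_s
print(((x - mean_s) / std_s).min(), ((x - mean_s) / std_s).max())  # ~0, ~1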
コード例 #12
def main():
    # test_num_updates
    ##########################################################
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10
    ##########################################################

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        "如果为正弦拟合任务,"
        "则构造FLAGS.meta_batch_size个正弦函数,每个正弦函数采FLAGS.update_batch_size * 2个样本"
        "默认值:FLAGS.update_batch_size=5, FLAGS.meta_batch_size=25"
        "则默认设置的正弦任务数据生成器每次产生数据的尺寸为:[25, 10, 1]"
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:
        "如果不是正弦拟合任务,"
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            "如果meta训练的迭代轮数为0,且为miniimagenet任务"
            "断言FLAGS.meta_batch_size=1, 即判断类别数是否为1"
            assert FLAGS.meta_batch_size == 1
            "断言FLAGS.update_batch_size=1, 即判断类别下的采样数是否为1"
            assert FLAGS.update_batch_size == 1
            "构造一个类别采样一个数据的数据生成器"
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet':  # TODO - use 15 val examples for imagenet?
                "如果任务是miniimagenet,"
                if FLAGS.train:
                    "如果是训练任务,"
                    "构造一个类别数为FLAGS.meta_batch_size, 每个类别下采样FLAGS.update_batch_size + 15个样本"
                    "默认值:FLAGS.update_batch_size=5, FLAGS.meta_batch_size=25"
                    "则默认设置的miniimagenet任务数据生成器每次产生数据的尺寸为: [25, 5+15, 84x84x3]"
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
                else:
                    "如果不是训练任务,"
                    "则默认设置的miniimagenet任务数据生成器每次产生数据的尺寸为: [25, 5*2, 84x84x3]"
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:
                "如果任务是omniglot,"
                "则默认设置的omniglot任务数据生成器每次产生数据的尺寸为: [25, 5*2, 28x28]"
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory
    "任务数据的输出维度, 对于:"
    "正弦拟合:dim_output=1"
    "omniglot: dim_output=data_generator.num_classes, data_generator.num_classes默认值为1"
    "miniimagenet: dim_output=data_generator.num_classes, data_generator.num_classes默认值为1"
    # dim_output
    ##############################################
    "数据的输出维度"
    dim_output = data_generator.dim_output
    # dim_input
    ##############################################
    if FLAGS.baseline == 'oracle':
        "如果FLAGS.baseline==oracle, "
        "则断言检查FLAGS.datasource是否为正弦拟合任务,否则报错"
        assert FLAGS.datasource == 'sinusoid'
        "将输入维度修改为3"
        dim_input = 3
        "将meta训练的迭代轮并入到预训练迭代轮数"
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        "将meta训练的迭代轮数置为0"
        FLAGS.metatrain_iterations = 0
    else:
        "正弦拟合:dim_input=1"
        "omniglot: dim_input=28x28"
        "miniimagenet: dim_input=84x84x3"
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        "如果任务为miniimagenet或者omniglot,"
        "则需要构造tensorflow的数据记载相关操作对应的计算图"
        tf_data_load = True
        "类别数"
        num_classes = data_generator.num_classes
        if FLAGS.train:  # only construct training model if needed
            "如果是训练阶段,"
            "初始化随机种子,保证实验可重复"
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
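
The comments in example #12 describe sinusoid batches of shape [25, 10, 1]. A minimal sketch of such a generator; the amplitude and phase ranges are assumptions, and the real DataGenerator internals are not shown in these examples:

import numpy as np

def sample_sinusoid_batch(meta_batch_size=25, num_samples=10):
    # one (amplitude, phase) pair per task
    amp = np.random.uniform(0.1, 5.0, size=meta_batch_size)
    phase = np.random.uniform(0.0, np.pi, size=meta_batch_size)
    x = np.random.uniform(-5.0, 5.0, size=(meta_batch_size, num_samples, 1))
    y = amp[:, None, None] * np.sin(x - phase[:, None, None])
    return x, y

x, y = sample_sinusoid_batch()
print(x.shape, y.shape)  # (25, 10, 1) (25, 10, 1)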
コード例 #13
def main():
    if FLAGS.datasource == 'sinusoid':  # sinusoid datasource
        if FLAGS.train:
            test_num_updates = 5  # at least 5 inner updates during training
        else:
            test_num_updates = 10  # at least 10 inner updates during testing
    else:
        if FLAGS.datasource == 'miniimagenet':  # miniimagenet datasource
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:  # at test time
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use a meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(1,
                                           FLAGS.meta_batch_size)  # only use one datapoint,

        else:
            if FLAGS.datasource == 'miniimagenet':  # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15,
                        FLAGS.meta_batch_size)  # 15 extra examples per class for the post-update meta-gradient
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2,
                        FLAGS.meta_batch_size)  # K examples per class for the update, K for evaluation
            else:
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2,
                    FLAGS.meta_batch_size)  # K examples per class for the update, K for evaluation

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':  # image datasources
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor(
            )  # load the data
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1
                               ])  # tf.slice(inputs, begin, size, name)
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change back to the original meta batch size when loading the model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
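
The 'oracle' branch in examples #12 and #13 widens dim_input from 1 to 3 because the true task parameters are appended to each input. A sketch of that construction (the exact feature order is an assumption):

import numpy as np

x = np.random.uniform(-5.0, 5.0, size=(10, 1))
amp, phase = 2.5, 0.3  # the generating sinusoid's parameters
# oracle input: [x, amplitude, phase] -> dim_input = 3
oracle_inp = np.concatenate(
    [x, np.full_like(x, amp), np.full_like(x, phase)], axis=1)
print(oracle_inp.shape)  # (10, 3)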
コード例 #14
ファイル: main.py プロジェクト: jayvischeng/MetaLearning
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1
    if FLAGS.datasource == 'sinusoid':
        # load data
        data_generator = DataGenerator(
            FLAGS.update_batch_size * 2,
            FLAGS.meta_batch_size)  # k=update_batch_size*2=10
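    # NOTE: this variant only handles the sinusoid datasource; any other value
    # of FLAGS.datasource would leave test_num_updates and data_generator unset.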

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
    tf_data_load = False
    input_tensors = None

    # construct meta learning model
    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    model.construct_model(input_tensors=input_tensors, prefix='metatrain_')

    model.summ_op = tf.summary.merge_all()
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                           max_to_keep=10)
    sess = tf.InteractiveSession()
    # writer = tf.summary.FileWriter("../../../logs", sess.graph)
    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(1) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()
    print("exp_string is: ", exp_string)

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)

    sin_test(model, saver, sess, exp_string, data_generator, test_num_updates)
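
All of these mains assemble exp_string by chained string concatenation of hyperparameters. A compact equivalent using a format string (the names mirror the flags used above):

def build_exp_string(num_classes, meta_batch_size, update_batch_size,
                     num_updates, update_lr):
    # mirrors the 'cls_5.mbs_25.ubs_5.numstep1.updatelr0.01'-style directory names
    return ('cls_{}.mbs_{}.ubs_{}.numstep{}.updatelr{}'
            .format(num_classes, meta_batch_size, update_batch_size,
                    num_updates, update_lr))

print(build_exp_string(5, 25, 5, 1, 0.01))
# cls_5.mbs_25.ubs_5.numstep1.updatelr0.01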
コード例 #15
def main():

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='sinusoid')
    elif FLAGS.datasource == 'ball':
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='ball')
    elif FLAGS.datasource == 'ball_file':
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='ball_file')
    else:  # 'rect_file"
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='rect_file',
                                       rect_truncated=rect_truncated)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
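    # NOTE: test_num_updates (used below) and rect_truncated (used above) are
    # assumed to be module-level globals in the source file; this main() never
    # defines them.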

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    model.construct_model()

    model.summ_op = tf.summary.merge_all()

    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                           max_to_keep=10)
    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    exp_string = get_exp_string(model)
    resume_itr = 0

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        # ME: test_num_updates = 10; 10 gradient updates
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
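As the "ME:" annotations note, each task draws update_batch_size * 2 samples; by the usual MAML convention the first half serves the inner-loop (support) updates and the second half the meta-gradient (query) evaluation. A NumPy sketch of that split under these assumptions, with illustrative sizes (the variable names are not from the repository):

import numpy as np

update_batch_size = 10                                 # 20 samples per task
samples = np.random.randn(update_batch_size * 2, 1)
support = samples[:update_batch_size]                  # inner-loop updates
query = samples[update_batch_size:]                    # meta-gradient evaluation
print(support.shape, query.shape)                      # (10, 1) (10, 1)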
Code Example #16
    def construct_model(self):
        self.sess = tf.InteractiveSession()
        if FLAGS.train == False:
            orig_meta_batch_size = FLAGS.meta_batch_size
            # always use meta batch size of 1 when testing.
            FLAGS.meta_batch_size = 1

        if FLAGS.datasource in ['sinusoid', 'mixture']:
            data_generator = DataGenerator(
                FLAGS.update_batch_size + FLAGS.update_batch_size_eval,
                FLAGS.meta_batch_size)
        else:
            if FLAGS.metatrain_iterations == 0 and FLAGS.datasource in [
                    'miniimagenet', 'multidataset'
            ]:
                assert FLAGS.meta_batch_size == 1
                assert FLAGS.update_batch_size == 1
                data_generator = DataGenerator(
                    1, FLAGS.meta_batch_size)  # only use one datapoint,
            else:
                if FLAGS.datasource in [
                        'miniimagenet', 'multidataset'
                ]:  # TODO - use 15 val examples for imagenet?
                    if FLAGS.train:
                        data_generator = DataGenerator(
                            FLAGS.update_batch_size + 15, FLAGS.meta_batch_size
                        )  # 15 extra examples per class for the post-update meta-gradient
                    else:
                        data_generator = DataGenerator(
                            FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                        )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory

        dim_output = data_generator.dim_output

        dim_input = data_generator.dim_input

        if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset']:
            tf_data_load = True
            num_classes = data_generator.num_classes

            if FLAGS.train:  # only construct training model if needed
                random.seed(5)
                if FLAGS.datasource in ['miniimagenet', 'omniglot']:
                    image_tensor, label_tensor = data_generator.make_data_tensor(
                    )
                elif FLAGS.datasource == 'multidataset':
                    image_tensor, label_tensor = data_generator.make_data_tensor_multidataset(
                        sel_num=self.clusters, train=True)
                inputa = tf.slice(
                    image_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                inputb = tf.slice(
                    image_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                labela = tf.slice(
                    label_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                labelb = tf.slice(
                    label_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb
                }

            random.seed(6)
            if FLAGS.datasource in ['miniimagenet', 'omniglot']:
                image_tensor, label_tensor = data_generator.make_data_tensor(
                    train=False)
            elif FLAGS.datasource == 'multidataset':
                image_tensor, label_tensor = data_generator.make_data_tensor_multidataset(
                    sel_num=self.clusters, train=False)
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            metaval_input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }
        else:
            tf_data_load = False
            input_tensors = None

        model = MAML(self.sess,
                     dim_input,
                     dim_output,
                     test_num_updates=self.test_num_updates)

        model.cluster_layer_0 = self.clusters

        if FLAGS.train or not tf_data_load:
            model.construct_model(input_tensors=input_tensors,
                                  prefix='metatrain_')
        if tf_data_load:
            model.construct_model(input_tensors=metaval_input_tensors,
                                  prefix='metaval_')
        model.summ_op = tf.summary.merge_all()
        saver = loader = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES),
                                        max_to_keep=10)

        if FLAGS.train == False:
            # change to original meta batch size when loading model.
            FLAGS.meta_batch_size = orig_meta_batch_size

        if FLAGS.train_update_batch_size == -1:
            FLAGS.train_update_batch_size = FLAGS.update_batch_size
        if FLAGS.train_update_lr == -1:
            FLAGS.train_update_lr = FLAGS.update_lr

        return model, saver, data_generator
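The repeated tf.slice calls above all split each meta-batch along the example axis at index num_classes * FLAGS.update_batch_size: everything before the cut becomes inputa/labela (adaptation data), everything after becomes inputb/labelb (meta-update data). A NumPy analogue of that split with illustrative sizes:

import numpy as np

meta_batch, num_classes, k_shot = 4, 5, 1        # illustrative sizes
dim_input = 84 * 84 * 3
batch = np.zeros((meta_batch, num_classes * (k_shot + k_shot), dim_input))

cut = num_classes * k_shot                       # num_classes * update_batch_size
inputa, inputb = batch[:, :cut, :], batch[:, cut:, :]
print(inputa.shape, inputb.shape)                # (4, 5, 21168) (4, 5, 21168)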
Code Example #17
File: main.py Project: augustdemi/demi
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    data_generator = DataGenerator()

    dim_output = data_generator.num_classes
    dim_input = data_generator.dim_input

    inputa, inputb, labela, labelb = data_generator.make_data_tensor()
    metatrain_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    # pred_weights = data_generator.pred_weights
    model = MAML(dim_input, dim_output)
    model.construct_model(input_tensors=metatrain_input_tensors, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=20)

    sess = tf.InteractiveSession()


    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    trained_model_dir = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(FLAGS.meta_batch_size) + '.ubs_' + str(
        FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
        FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)

    print(">>>>> trained_model_dir: ", FLAGS.logdir + '/' + trained_model_dir)


    resume_itr = 0

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    print("================================================================================")
    print('initial weights norm: ', np.linalg.norm(sess.run('model/w1:0')))
    print('initial last weights: ', sess.run('model/w1:0')[-1])
    print('initial bias: ', sess.run('model/b1:0'))
    print("================================================================================")

    ################## Train ##################

    if FLAGS.resume:
        model_file = None
        if FLAGS.model.startswith('m2'):
            trained_model_dir = 'sbjt' + str(FLAGS.sbjt_start_idx) + '.ubs_' + str(
                FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + trained_model_dir)
        print(">>>>> trained_model_dir: ", FLAGS.logdir + '/' + trained_model_dir)

        w = None
        b = None
        print(">>>> model_file1: ", model_file)

        if model_file:
            if FLAGS.test_iter > 0:
                files = os.listdir(model_file[:model_file.index('model')])
                if 'model' + str(FLAGS.test_iter) + '.index' in files:
                    model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                    print(">>>> model_file2: ", model_file)
            print("1. Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
            b = sess.run('model/b1:0').tolist()
            print("updated weights from ckpt: ", np.array(b))
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])

    elif FLAGS.keep_train_dir:  # when the model needs to be initialized from another model.
        resume_itr = 0
        print('resume_itr: ', resume_itr)
        model_file = tf.train.latest_checkpoint(FLAGS.keep_train_dir)
        print(">>>>> base_model_dir: ", FLAGS.keep_train_dir)

        if FLAGS.test_iter > 0:
            files = os.listdir(model_file[:model_file.index('model')])
            if 'model' + str(FLAGS.test_iter) + '.index' in files:
                model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                print(">>>> model_file2: ", model_file)

        print("2. Restoring model weights from " + model_file)
        saver.restore(sess, model_file)
        print("updated weights from ckpt: ", sess.run('model/b1:0'))

    elif FLAGS.model.startswith('s4'):
        from feature_layers import feature_layer
        three_layers = feature_layer(10, 1)
        print('FLAGS.base_vae_model: ', FLAGS.base_vae_model)
        three_layers.model_intensity.load_weights(FLAGS.base_vae_model + '.h5')
        w = three_layers.model_intensity.layers[-1].get_weights()[0]
        b = three_layers.model_intensity.layers[-1].get_weights()[1]
        print('s2 b: ', b)
        print('s2 w: ', w)
        print('-----------------------------------------------------------------')
        with tf.variable_scope("model", reuse=True) as scope:
            scope.reuse_variables()
            b1 = tf.get_variable("b1", [1, 2]).assign(np.array(b))
            w1 = tf.get_variable("w1", [300, 1, 2]).assign(np.array(w))
            sess.run(b1)
            sess.run(w1)
        print("after: ", sess.run('model/b1:0'))
        print("after: ", sess.run('model/w1:0'))

    if not FLAGS.all_sub_model:
        trained_model_dir = 'sbjt' + str(FLAGS.sbjt_start_idx) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
            FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)

    print("================================================================================")

    train(model, saver, sess, trained_model_dir, metatrain_input_tensors, resume_itr)

    end_time = datetime.now()
    elapse = end_time - start_time
    print("================================================================================")
    print(">>>>>> elapse time: " + str(elapse))
    print("================================================================================")
Code Example #18
def main():
    if FLAGS.train:
        test_num_updates = 20
    elif FLAGS.from_scratch:
        test_num_updates = 200
    else:
        test_num_updates = 50

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    sess = tf.InteractiveSession()

    if not FLAGS.dataset == 'imagenet':
        data_generator = DataGenerator(FLAGS.inner_update_batch_size_train + FLAGS.outer_update_batch_size,
                                       FLAGS.inner_update_batch_size_val + FLAGS.outer_update_batch_size,
                                       FLAGS.meta_batch_size)
    else:
        data_generator = DataGeneratorImageNet(FLAGS.inner_update_batch_size_train + FLAGS.outer_update_batch_size,
                                               FLAGS.inner_update_batch_size_val + FLAGS.outer_update_batch_size,
                                               FLAGS.meta_batch_size)

    dim_output_train = data_generator.dim_output_train
    dim_output_val = data_generator.dim_output_val
    dim_input = data_generator.dim_input


    tf_data_load = True
    num_classes_train = data_generator.num_classes_train
    num_classes_val = data_generator.num_classes_val

    if FLAGS.train: # only construct training model if needed
        random.seed(5)
        image_tensor, label_tensor = data_generator.make_data_tensor()
        inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes_train*FLAGS.inner_update_batch_size_train, -1])
        inputb = tf.slice(image_tensor, [0,num_classes_train*FLAGS.inner_update_batch_size_train, 0], [-1,-1,-1])
        labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes_train*FLAGS.inner_update_batch_size_train, -1])
        labelb = tf.slice(label_tensor, [0,num_classes_train*FLAGS.inner_update_batch_size_train, 0], [-1,-1,-1])
        input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    random.seed(6)
    image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
    inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes_val*FLAGS.inner_update_batch_size_val, -1])
    inputb = tf.slice(image_tensor, [0,num_classes_val*FLAGS.inner_update_batch_size_val, 0], [-1,-1,-1])
    labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes_val*FLAGS.inner_update_batch_size_val, -1])
    labelb = tf.slice(label_tensor, [0,num_classes_val*FLAGS.inner_update_batch_size_val, 0], [-1,-1,-1])
    metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    model = MAML(dim_input, dim_output_train, dim_output_val, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    if FLAGS.debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.log_inner_update_batch_size_val == -1:
        FLAGS.log_inner_update_batch_size_val = FLAGS.inner_update_batch_size_val
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = ''
    exp_string += '.nu_' + str(FLAGS.num_updates) + '.ilr_' + str(FLAGS.train_update_lr)
    if FLAGS.meta_lr != 0.001:
        exp_string += '.olr_' + str(FLAGS.meta_lr)
    if FLAGS.mt_mode != 'gtgt':
        if FLAGS.partition_algorithm == 'hyperplanes':
            exp_string += '.m_' + str(FLAGS.margin)
        if FLAGS.partition_algorithm == 'kmeans' or FLAGS.partition_algorithm == 'kmodes':
            exp_string += '.k_' + str(FLAGS.num_clusters)
            exp_string += '.p_' + str(FLAGS.num_partitions)
            if FLAGS.scaled_encodings and FLAGS.num_partitions != 1:
                exp_string += '.scaled'
        if FLAGS.mt_mode == 'encenc':
            exp_string += '.ned_' + str(FLAGS.num_encoding_dims)
        elif FLAGS.mt_mode == 'semi':
            exp_string += '.pgtgt_' + str(FLAGS.p_gtgt)
    exp_string += '.mt_' + FLAGS.mt_mode
    exp_string += '.mbs_' + str(FLAGS.meta_batch_size) + \
                  '.nct_' + str(FLAGS.num_classes_train) + \
                  '.iubst_' + str(FLAGS.inner_update_batch_size_train) + \
                  '.iubsv_' + str(FLAGS.log_inner_update_batch_size_val) + \
                  '.oubs' + str(FLAGS.outer_update_batch_size)
    exp_string = exp_string[1:]     # get rid of leading period

    if FLAGS.on_encodings:
        exp_string += '.onenc'
        exp_string += '.nhl_' + str(FLAGS.num_hidden_layers)
    if FLAGS.num_filters != 64:
        exp_string += '.hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += '.maxpool'
    if FLAGS.stop_grad:
        exp_string += '.stopgrad'
    if FLAGS.norm == 'batch_norm':
        exp_string += '.batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += '.layernorm'
    elif FLAGS.norm == 'None':
        exp_string += '.nonorm'
    else:
        print('Norm setting not recognized.')
    if FLAGS.resnet:
        exp_string += '.res{}parts{}'.format(FLAGS.num_res_blocks, FLAGS.num_parts_per_res_block)
    if FLAGS.miniimagenet_only:
        exp_string += '.mini'
    if FLAGS.suffix != '':
        exp_string += '.' + FLAGS.suffix

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()

    print(exp_string)

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(logdir + '/' + exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1+5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
        else:
            print("No checkpoint found")

    if FLAGS.from_scratch:
        exp_string = ''

    if FLAGS.from_scratch and not os.path.isdir(logdir):
        os.makedirs(logdir)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
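The long exp_string concatenation above encodes every hyperparameter that distinguishes a run into the checkpoint directory name, so restoring a run only requires passing the same flags again. The pattern reduces to joining key/value pairs into a path component; a toy sketch (flag names illustrative):

flags = {'nu': 5, 'ilr': 0.01, 'olr': 0.001, 'mbs': 4}
exp_string = '.'.join('{}_{}'.format(k, v) for k, v in flags.items())
print(exp_string)    # nu_5.ilr_0.01.olr_0.001.mbs_4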
Code Example #19
File: main.py Project: csyanbin/Meta-SGD
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    # if FLAGS.datasource == 'sinusoid':
    #     data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)
    # else:
    #     if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
    #         assert FLAGS.meta_batch_size == 1
    #         assert FLAGS.update_batch_size == 1
    #         data_generator = DataGenerator(1, FLAGS.meta_batch_size)  # only use one datapoint,
    #     else:
    #         if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet?
    #             if FLAGS.train:
    #                 data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
    #             else:
    #                 data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
    #         else:
    #             data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory

    dim_output = FLAGS.num_classes
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = 84 * 84 * 3

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train:
        model.construct_model(input_tensors=None, prefix='metatrain_')
    else:
        model.construct_model(input_tensors=None, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=40)
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        print(var.name)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')
    if FLAGS.lr_mode > 0:
        exp_string += 'lrmode' + str(FLAGS.lr_mode)
    print(exp_string)

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    #tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, resume_itr)
    else:
        test(model, saver, sess, exp_string, test_num_updates)
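For the 'oracle' baseline above, dim_input is set to 3 for sinusoid. A hedged reading, consistent with the original MAML sinusoid code, is that the scalar x is concatenated with the task's amplitude and phase, so the network sees the task identity directly instead of having to adapt to it:

import numpy as np

amp, phase = 2.0, 0.3                                   # illustrative task parameters
x = np.random.uniform(-5.0, 5.0, size=(10, 1))
oracle_inputs = np.concatenate(
    [x, np.full_like(x, amp), np.full_like(x, phase)], axis=1)
print(oracle_inputs.shape)                              # (10, 3)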
Code Example #20
def main():
    data_generator = DataGenerator(FLAGS.update_batch_size,
                                   FLAGS.meta_batch_size,
                                   k_shot=FLAGS.k_shot)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    if FLAGS.datasource == 'ml':
        input_tensors = {
            'inputa': tf.placeholder(tf.int32, shape=[None, None, 2]),
            'inputb': tf.placeholder(tf.int32, shape=[None, None, 2]),
            'labela': tf.placeholder(tf.float32, shape=[None, None, 1]),
            'labelb': tf.placeholder(tf.float32, shape=[None, None, 1])
        }
    elif FLAGS.datasource == 'bpr' or FLAGS.datasource == 'bpr_time':
        input_tensors = {
            'inputa': tf.placeholder(tf.int32, shape=[None, None, 3]),
            'inputb': tf.placeholder(tf.int32, shape=[None, None, 3]),
        }
    else:
        raise Exception('non-supported data source: {}'.format(
            FLAGS.datasource))

    model = MAML(dim_input, dim_output)
    if FLAGS.train or FLAGS.test_existing_user:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    else:
        model.construct_model(input_tensors=input_tensors, prefix='META_TEST')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    sess = tf.InteractiveSession()

    exp_string = 'mtype_{}.mbs_{}.ubs_{}.meta_lr_{}.' \
                 'update_step_{}.update_lr_{}.' \
                 'lambda_lr_{}.avg_f_{}' \
                 '.time_{}'.format(FLAGS.datasource,
                                   FLAGS.meta_batch_size,
                                   FLAGS.update_batch_size,
                                   FLAGS.meta_lr, FLAGS.num_updates,
                                   FLAGS.update_lr,
                                   FLAGS.lambda_lr,
                                   FLAGS.use_avg_init,
                                   str(datetime.now()))

    resume_itr = 0
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()
    if FLAGS.resume:
        model_path = '{}/mlRRS/model/{}/model_{}'.format(
            FLAGS.logdir, FLAGS.load_dir, FLAGS.resume_iter)
        if os.path.exists(model_path + '.meta'):
            loader.restore(sess=sess, save_path=model_path)
            resume_itr = FLAGS.resume_iter
        else:
            raise Exception('No model saved at path {}'.format(model_path))
    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    if FLAGS.test_existing_user:
        test_existing_user(model, saver, sess, exp_string, data_generator,
                           resume_itr)
    if FLAGS.test:
        test(model, saver, sess, exp_string, data_generator, resume_itr)
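Unlike the queue-based examples, this one declares its inputs as placeholders, so task batches are fed at run time through feed_dict rather than produced by make_data_tensor. A minimal sketch of that pattern, again assuming tensorflow.compat.v1; the graph body is a stand-in, not the repository's model:

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

inputa = tf.placeholder(tf.int32, shape=[None, None, 2])   # [tasks, samples, (user, item)]
doubled = inputa * 2                                       # stand-in for the real model graph

with tf.Session() as sess:
    batch = np.zeros((3, 4, 2), dtype=np.int32)
    print(sess.run(doubled, feed_dict={inputa: batch}).shape)   # (3, 4, 2)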
Code Example #21
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    else:
        if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifar100':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0: #and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet': #or FLAGS.datasource == 'cifar100': # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory


    dim_output = data_generator.dim_output

    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input


    random.seed(7)
    X = data_generator.make_autoencoder_data_tensor(train=True)
    Y = data_generator.make_autoencoder_data_tensor(train=False)
    autoencoder_input_tensors = {'X': X, 'Y': Y}
    dim_s = 32

    autoencoder = Autoencoder(dim_input, dim_s)
    if FLAGS.train:
        autoencoder.construct_autoencoder(input_tensors=autoencoder_input_tensors, prefix='autoencoder_train')
    else:
        autoencoder.construct_autoencoder(input_tensors=autoencoder_input_tensors, prefix='autoencoder_test')



    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'cifar100':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train: # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
            labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])

            # s_tensor = tf.map_fn(lambda x: autoencoder.encode(x), inputa[0])
            # s_tensor = tf.reshape(s_tensor, [s_tensor.get_shape()[0], -1])
            # s_tensor = tf.reduce_sum(s_tensor, 0)

            s_tensor = tf.map_fn(lambda x: make_s(x, autoencoder), inputa)

            input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb, 's_tensor':s_tensor}

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
        inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
        labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
        s_tensor = tf.map_fn(lambda x: make_s(x, autoencoder), inputa)

        metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb, 's_tensor':s_tensor}
    else:
        tf_data_load = False
        input_tensors = None

    #autoencoder_for_maml = autoencoder.encode(input_tensors = input_tensors['inputa'])

    model = MAML(autoencoder.train_phase, dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    # else:
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()
    autoencoder.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    sess = tf.InteractiveSession()


    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_'+str(FLAGS.num_classes)+'.mbs_'+str(FLAGS.meta_batch_size) + '.ubs_' + str(FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    # exp_string_model = exp_string
    # exp_string_autoencoder = exp_string
    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        autoencoder_file = tf.train.latest_checkpoint(FLAGS.logdir_autoencoder + '/' + exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1+5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

            # ind1 = autoencoder_file.index('autoencoder')
            # resume_itr = int()
            print("Restoring autoencoder weights from " + autoencoder_file)
            w1 = sess.run(autoencoder.weights)   # weights before restore
            saver.restore(sess, autoencoder_file)
            w2 = sess.run(autoencoder.weights)   # weights after restore, kept for inspection

    if FLAGS.train:
        print('training now')
        train(model, autoencoder, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, autoencoder, saver, sess, exp_string, data_generator, test_num_updates)
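The s_tensor lines above use tf.map_fn to run the autoencoder once per task, i.e. over the leading meta-batch axis of inputa. A sketch with a stand-in encoder (mean-pooling here, purely illustrative), assuming tensorflow.compat.v1:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

inputa = tf.zeros([4, 5, 8])     # [meta_batch, samples_per_task, dim]
s_tensor = tf.map_fn(lambda task: tf.reduce_mean(task, axis=0), inputa)

with tf.Session() as sess:
    print(sess.run(s_tensor).shape)    # (4, 8): one summary vector per task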
Code Example #22
def main():
    print('Train(0) or Test(1)?')
    train_ = input()
    train_count = 100
    if train_ == '0':
        FLAGS.train = True
        print('Number of training iterations in training mode:')
        train_count = input()
        FLAGS.metatrain_iterations = int(train_count)
    else:
        FLAGS.train = False

    print('Select GPU:')
    gpu_index = input()

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_index
    config_gpu = tf.ConfigProto()
    config_gpu.gpu_options.allow_growth = True

    if FLAGS.train is True:
        test_num_updates = 1
    else:
        test_num_updates = 10  # the original code uses 10 inner gradient steps at test time

    if FLAGS.train is False:  # testing
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1
    print('main.py: creating data_generator')
    if FLAGS.train:
        data_generator = DataGeneratorOneInstance(FLAGS.update_batch_size + 15,
                                                  FLAGS.meta_batch_size)
        # data_generator = DataGenerator_embedding(FLAGS.update_batch_size + 15, FLAGS.meta_batch_size)
    else:
        data_generator = DataGeneratorOneInstance(FLAGS.update_batch_size * 2,
                                                  FLAGS.meta_batch_size)
        # data_generator = DataGenerator_embedding(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)

    # output dimensions
    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
    print('dim_input in main is {}'.format(dim_input))

    tf_data_load = True
    num_classes = data_generator.num_classes
    sess = tf.InteractiveSession(config=config_gpu)
    # sess = tf.InteractiveSession()

    if FLAGS.train:  # only construct training model if needed
        random.seed(5)
        '''
        Notes on image_tensor and label_tensor:
        return all_image_batches, all_label_batches
        all_image_batches:
        [batch1:[pic1, pic2, ...], batch2:[]...], where each pic is [0.1, 0.08, ... of length 84*84*3]
        all_label_batches:
        [batch1:[  [[0,1,0..], [1,0,0..], []..]  ], batch2:[]...], where each [0,1,..] has num_classes entries
        '''
        # make_data_tensor
        print(
            'main.py: train: data_generator.make_data_tensor(); slice the result into inputa etc.')
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=True)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    # build the validation data so accuracy can be printed during training
    random.seed(6)
    print('main.py: val: data_generator.make_data_tensor()')
    image_tensor, label_tensor = data_generator.make_data_tensor(
        train=False)  # train=False only affects the folder and batch_count
    inputa = tf.slice(
        image_tensor, [0, 0, 0],
        [-1, num_classes * FLAGS.update_batch_size, -1])  # indices 0 through 5*4 form input_a
    inputb = tf.slice(image_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    labela = tf.slice(label_tensor, [0, 0, 0],
                      [-1, num_classes * FLAGS.update_batch_size, -1])
    labelb = tf.slice(label_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    metaval_input_tensors = {
        'inputa': inputa,
        'inputb': inputb,
        'labela': labela,
        'labelb': labelb
    }

    print('model = MAML()')
    # test_num_updates: number of inner gradient steps (train: 1, test: 10)
    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        # construct_model must be called once initialization is done
        print('model.construct_model(\'metatrain_\')')
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        print('model.construct_model(\'metaval_\')')
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    if FLAGS.train is False:
        # use the original meta batch size in the test phase
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0  # resume training from a checkpoint
    model_file = None
    # initialize variables
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()
    tf.train.start_queue_runners()

    # cls_5.mbs_4.ubs_5.numstep5.updatelr0.01hidden32maxpoolbatchnorm
    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("读取已有训练数据Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        print('main.py: jumping to train(model, saver, sess, exp_string...)...')
        # my(model, sess)
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        print('main.py: jumping to _test(model, saver, sess, exp_string...)...')
        _test(model, saver, sess, exp_string, data_generator, test_num_updates)
Code Example #23
def main():
    sess = tf.InteractiveSession()
    if FLAGS.train:
        test_num_updates = FLAGS.num_updates
    else:
        test_num_updates = FLAGS.num_updates_test

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource in ['2D']:
        data_generator = DataGenerator(
            FLAGS.update_batch_size + FLAGS.update_batch_size_eval,
            FLAGS.meta_batch_size)
    else:
        if FLAGS.train:
            data_generator = DataGenerator(FLAGS.update_batch_size + 15,
                                           FLAGS.meta_batch_size)
        else:
            data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                           FLAGS.meta_batch_size)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    if FLAGS.datasource in ['plainmulti', 'artmulti']:
        num_classes = data_generator.num_classes
        if FLAGS.train:
            random.seed(5)
            if FLAGS.datasource == 'plainmulti':
                image_tensor, label_tensor = data_generator.make_data_tensor_plainmulti(
                )
            elif FLAGS.datasource == 'artmulti':
                image_tensor, label_tensor = data_generator.make_data_tensor_artmulti(
                )
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }
        else:
            random.seed(6)
            if FLAGS.datasource == 'plainmulti':
                image_tensor, label_tensor = data_generator.make_data_tensor_plainmulti(
                    train=False)
            elif FLAGS.datasource == 'artmulti':
                image_tensor, label_tensor = data_generator.make_data_tensor_artmulti(
                    train=False)
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            metaval_input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }
    else:
        input_tensors = None
        metaval_input_tensors = None

    model = MAML(sess,
                 dim_input,
                 dim_output,
                 test_num_updates=test_num_updates)

    if FLAGS.train:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    else:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')

    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                           max_to_keep=60)

    if FLAGS.train == False:
        FLAGS.meta_batch_size = orig_meta_batch_size

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.update_lr) + '.metalr' + str(
                        FLAGS.meta_lr) + '.emb_loss_weight' + str(
                            FLAGS.emb_loss_weight) + '.hidden_dim' + str(
                                FLAGS.hidden_dim)

    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = '{0}/{2}/model{1}'.format(FLAGS.logdir, FLAGS.test_epoch,
                                               exp_string)
        if model_file:
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
    resume_itr = 0

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, sess, data_generator)
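Note the out-of-order format indices when the checkpoint path is built above: {0} is FLAGS.logdir, {2} is exp_string, and {1} is FLAGS.test_epoch, which yields '<logdir>/<exp_string>/model<epoch>'. A one-liner demonstrating the ordering (values illustrative):

print('{0}/{2}/model{1}'.format('logs', 10000, 'cls_5.mbs_4'))
# logs/cls_5.mbs_4/model10000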
Code Example #24
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 2
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet':  # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(max_to_keep=10)

    #saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        print("Seeing if resume....")
        print("File string: ", FLAGS.logdir + '/' + exp_string)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        print("model file name: ", model_file)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
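This example builds its Saver over all variables (the commented-out line shows the trainable-only alternative used elsewhere). The difference matters because non-trainable state such as global steps or batch-norm moving statistics is only captured by the former. A small sketch of the two variable collections, assuming tensorflow.compat.v1:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

w = tf.get_variable('w', [2])                                   # trainable by default
step = tf.get_variable('step', [], trainable=False,
                       initializer=tf.zeros_initializer())

saver_all = tf.train.Saver(max_to_keep=10)                      # saves w and step
saver_trainable = tf.train.Saver(
    tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))        # saves w only

print([v.name for v in tf.global_variables()])                  # ['w:0', 'step:0']
print([v.name for v in tf.trainable_variables()])               # ['w:0']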
Code Example #25
def main():
    test_num_updates = 1
    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 100 when testing.
        FLAGS.meta_batch_size = 100
    data_generator = DataGenerator(batch_size=FLAGS.meta_batch_size)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
    num_classes = data_generator.num_classes

    if FLAGS.train:  # only construct training model if needed
        random.seed(5)
        image_tensor, label_tensor = data_generator.make_data_tensor()
        input_tensors = {'input': image_tensor, 'label': label_tensor}

    random.seed(6)
    image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
    metaval_input_tensors = {'input': image_tensor, 'label': label_tensor}

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.vanilla:
        if FLAGS.train:
            model.construct_vanilla_model(input_tensors=input_tensors,
                                          prefix='metatrain_')
        model.construct_vanilla_model(input_tensors=metaval_input_tensors,
                                      prefix='metaval_')
    else:
        if FLAGS.train:
            model.construct_model(input_tensors=input_tensors,
                                  prefix='metatrain_')
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(tf.global_variables(), max_to_keep=10)
    sess = tf.InteractiveSession()
    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    exp_string = FLAGS.dataset \
                + '_backbone_' + FLAGS.backbone \
                + '_scalar_lr_' + str(FLAGS.scalar_lr) \
                + '_mbs_'+str(FLAGS.meta_batch_size) \
                + '.dict_' + str(FLAGS.dict_size) + '.numstep' \
                + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.update_lr) \
                + '.vanilla_' + str(FLAGS.vanilla) \
                + '.fix_v_' + str(FLAGS.fix_v) \
                + '.alpha_' + str(FLAGS.alpha)

    if FLAGS.dropout_ratio != 0.5:
        exp_string += '_dropout_' + str(FLAGS.dropout_ratio)
    if FLAGS.vanilla and FLAGS.optimizer != 'sgd':
        exp_string += FLAGS.optimizer
    exp_string += '_weight_decay_' + str(FLAGS.weight_decay)
    if FLAGS.dot:
        exp_string += '_dot'
    if FLAGS.modulate in ['all', 'last', 'before_fc']:
        exp_string += '_modulate_' + FLAGS.modulate + '_size_' + str(
            FLAGS.film_dict_size)
    print(exp_string)
    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()
    prev_best_accu = 0
    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            loader.restore(sess, model_file)
            orig_train = FLAGS.train
            FLAGS.train = False
            if FLAGS.vanilla:
                prev_best_accu = test_vanilla(model, saver, sess, exp_string,
                                              data_generator)
            else:
                prev_best_accu = test(model, saver, sess, exp_string,
                                      data_generator)
            FLAGS.train = orig_train

    if FLAGS.vanilla:
        if FLAGS.train:
            train_vanilla(model, saver, sess, exp_string, data_generator,
                          prev_best_accu, resume_itr)
        else:
            test_vanilla(model, saver, sess, exp_string, data_generator)
    else:
        if FLAGS.train:
            train(model, saver, sess, exp_string, data_generator,
                  prev_best_accu, resume_itr)
        else:
            test(model, saver, sess, exp_string, data_generator)