コード例 #1
0
ファイル: run_omniglot.py プロジェクト: taohong08/FMRL
def main():
    """
    Load Omniglot, train a model on it (or restore one), and evaluate it.

    Settings come from ``argument_parser()``. Unless ``args.pretrained`` is
    set, the model is trained via ``train`` (which receives
    ``args.checkpoint`` as its checkpoint location); otherwise the latest
    checkpoint found in ``args.checkpoint`` is restored. Train and test
    accuracy are printed at the end.
    """
    args = argument_parser().parse_args()
    random.seed(args.seed)

    train_set, test_set = split_dataset(read_dataset(DATA_DIR))
    train_set = list(augment_dataset(train_set))
    test_set = list(test_set)

    model = OmniglotModel(args.classes, **model_kwargs(args))

    with tf.Session() as sess:
        # Built once: the same evaluation settings are handed to train() and
        # used for the final report below.
        eval_kwargs = evaluate_kwargs(args)
        if not args.pretrained:
            print('Training...')
            train(sess,
                  model,
                  train_set,
                  test_set,
                  args.checkpoint,
                  eval_kwargs=eval_kwargs,
                  **train_kwargs(args))
        else:
            print('Restoring from checkpoint...')
            tf.train.Saver().restore(
                sess, tf.train.latest_checkpoint(args.checkpoint))

        print('Evaluating...')
        print('Train accuracy: ' +
              str(evaluate(sess, model, train_set, **eval_kwargs)))
        print('Test accuracy: ' +
              str(evaluate(sess, model, test_set, **eval_kwargs)))
コード例 #2
0
def main():
    """
    Train (or restore) a Mini-ImageNet model and report its test accuracy.

    Settings come from the command line. Without ``--pretrained`` the model
    is trained (``args.checkpoint`` is handed to ``train``); with it, the
    newest checkpoint in ``args.checkpoint`` is loaded instead.
    """
    args = argument_parser().parse_args()
    random.seed(args.seed)

    # The validation split is read but not evaluated by this script.
    train_set, val_set, test_set = read_dataset(DATA_DIR)
    model = MiniImageNetModel(args.classes, **model_kwargs(args))

    with tf.Session() as sess:
        if args.pretrained:
            print('Restoring from checkpoint...')
            newest_checkpoint = tf.train.latest_checkpoint(args.checkpoint)
            tf.train.Saver().restore(sess, newest_checkpoint)
        else:
            print('Training...')
            train_options = train_kwargs(args)
            train(sess, model, train_set, test_set, args.checkpoint,
                  **train_options)

        print('Evaluating...')
        eval_options = evaluate_kwargs(args)
        test_accuracy = evaluate(sess, model, test_set, **eval_options)
        print('Test accuracy: ' + str(test_accuracy))
コード例 #3
0
def main():
    """
    Train or restore an Omniglot model, then print train/test accuracy.

    The training split is passed through ``augment_dataset``; both splits are
    materialized as lists because each is consumed more than once below.
    With ``--pretrained`` the newest checkpoint in ``args.checkpoint`` is
    loaded instead of training.
    """
    args = argument_parser().parse_args()
    random.seed(args.seed)

    raw_train, raw_test = split_dataset(read_dataset(DATA_DIR))
    train_set = list(augment_dataset(raw_train))
    test_set = list(raw_test)

    model = OmniglotModel(args.classes, **model_kwargs(args))

    with tf.Session() as sess:
        if args.pretrained:
            print('Restoring from checkpoint...')
            newest = tf.train.latest_checkpoint(args.checkpoint)
            tf.train.Saver().restore(sess, newest)
        else:
            print('Training...')
            train(sess, model, train_set, test_set, args.checkpoint,
                  **train_kwargs(args))

        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)
        for label, split in (('Train accuracy: ', train_set),
                             ('Test accuracy: ', test_set)):
            print(label + str(evaluate(sess, model, split, **eval_kwargs)))
コード例 #4
0
def main():
    """
    Train, resume, or only evaluate an Omniglot model, then print accuracy.

    Modes (flags from ``argument_parser()``):
      * default: train from scratch (``resume_itr`` = 0).
      * ``--pretrained``: restore the newest checkpoint in ``args.checkpoint``
        and pass the iteration number encoded in its name to ``train`` as
        ``resume_itr``.
      * ``--test``: restore that checkpoint and skip training.
    Train and test accuracy are printed in every mode.

    Raises:
        ValueError: if a restore is requested but ``args.checkpoint`` contains
            no checkpoint.
    """
    args = argument_parser().parse_args()
    random.seed(args.seed)

    train_set, test_set = split_dataset(read_dataset(DATA_DIR))
    train_set = list(augment_dataset(train_set))
    test_set = list(test_set)

    model = OmniglotModel(args.classes, **model_kwargs(args))

    with tf.Session() as sess:
        resume_itr = 0  # Zero iterations have already been trained.

        if args.pretrained or args.test:  # Testing requires a trained model.
            print('Restoring from checkpoint...')
            saved_model = tf.train.latest_checkpoint(args.checkpoint)
            if saved_model is None:
                raise ValueError(
                    'No checkpoint found in ' + str(args.checkpoint))
            tf.train.Saver().restore(sess, saved_model)

            # Saver.save(..., global_step=N) names checkpoints "<prefix>-N";
            # parse the suffix after the last '-' rather than searching for
            # the first "model.ckpt", which misfires if a directory name
            # also contains that text.
            resume_itr = int(saved_model.rsplit('-', 1)[-1])

        if not args.test:
            print('Training...')
            train(sess, model, train_set, test_set, args.checkpoint,
                  resume_itr, **train_kwargs(args))

        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)
        print('Train accuracy: ' +
              str(evaluate(sess, model, train_set, **eval_kwargs)))
        print('Test accuracy: ' +
              str(evaluate(sess, model, test_set, **eval_kwargs)))
コード例 #5
0
def main():
    """
    Train (or restore) a Mini-ImageNet model from pickled splits and test it.

    The train/val/test splits are built with ``dataset_mini`` and loaded via
    ``load_data_pkl``. With ``--pretrained``, ``args.checkpoint`` may be
    either a checkpoint directory (its newest checkpoint is restored) or an
    explicit checkpoint path. Only test accuracy is reported.
    """
    args = argument_parser().parse_args()
    random.seed(args.seed)

    # Load the pickled splits instead of read_dataset(DATA_DIR).
    # NOTE(review): these settings are interpreted by dataset_mini;
    # presumably 600 images per class and 84x84 RGB inputs -- confirm there.
    n_examples = 600
    n_episodes = 100
    args_data = {}
    args_data['x_dim'] = '84,84,3'
    args_data['ratio'] = 1.0
    args_data['seed'] = 1000
    train_set = dataset_mini(n_examples, n_episodes, 'train', args_data)
    val_set = dataset_mini(n_examples, n_episodes, 'val', args_data)
    test_set = dataset_mini(n_examples, n_episodes, 'test', args_data)
    train_set.load_data_pkl()
    val_set.load_data_pkl()  # loaded, but not evaluated below
    test_set.load_data_pkl()

    model = MiniImageNetModel(args.classes, **model_kwargs(args))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        if not args.pretrained:
            print('Training...')
            train(sess, model, train_set, test_set, args.checkpoint,
                  **train_kwargs(args))
        else:
            print('Restoring from checkpoint...')
            # latest_checkpoint() returns None when args.checkpoint is not a
            # checkpoint directory; in that case treat it as an explicit
            # checkpoint path. Other restore errors are left to propagate
            # rather than being swallowed.
            latest = tf.train.latest_checkpoint(args.checkpoint)
            print(latest)
            if latest is not None:
                tf.train.Saver().restore(sess, latest)
            else:
                print(args.checkpoint)
                tf.train.Saver().restore(sess, args.checkpoint)

        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)
        # Only the test split is evaluated; train/validation are skipped.
        print('Test accuracy: ' +
              str(evaluate(sess, model, test_set, **eval_kwargs)))
コード例 #6
0
def main():
    """
    Train (or restore) a progressive Omniglot model, reporting to Neptune.

    Hyperparameters are read from the Neptune context via ``neptune_args``.
    The final train and test accuracies are printed and also sent to the
    ``final_train_accuracy`` / ``final_test_accuracy`` numeric channels.
    """
    ctx = neptune.Context()
    ctx.integrate_with_tensorflow()
    numeric = neptune.ChannelType.NUMERIC
    train_channel = ctx.create_channel('final_train_accuracy', numeric)
    test_channel = ctx.create_channel('final_test_accuracy', numeric)

    args = neptune_args(ctx)
    print('args:\n', args)
    random.seed(args.seed)

    train_set, test_set = split_dataset(read_dataset(args.omniglot_src))
    train_set = list(augment_dataset(train_set))
    test_set = list(test_set)

    model = ProgressiveOmniglotModel(args.classes, **model_kwargs(args))

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = args.allow_growth
    with tf.Session(config=session_config) as sess:
        if args.pretrained:
            print('Restoring from checkpoint...')
            checkpoint_path = tf.train.latest_checkpoint(args.checkpoint)
            tf.train.Saver().restore(sess, checkpoint_path)
        else:
            print('Training...')
            train(sess, model, train_set, test_set, args.checkpoint,
                  **train_kwargs(args))

        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)

        # Report each split's final accuracy to stdout and to its channel.
        for label, dataset, channel in (
                ('final_train_accuracy:', train_set, train_channel),
                ('final_test_accuracy:', test_set, test_channel)):
            accuracy = evaluate(sess, model, dataset, **eval_kwargs)
            print(label, accuracy)
            channel.send(accuracy)
コード例 #7
0
def main():
    """
    Train (or restore) a Mini-ImageNet model on a selected GPU and test it.

    ``args.gpu`` is exported as ``CUDA_VISIBLE_DEVICES`` (devices ordered by
    PCI bus id) before the session is created. ``args.metatransfer`` selects
    ``MiniImageNetMetaTransferModel`` instead of ``MiniImageNetModel``.
    Only test accuracy is reported.
    """
    args = argument_parser().parse_args()
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": str(args.gpu),
    })

    random.seed(args.seed)

    # val_set is read but not evaluated by this script.
    train_set, val_set, test_set = read_dataset(DATA_DIR)
    model_class = (MiniImageNetMetaTransferModel if args.metatransfer
                   else MiniImageNetModel)
    model = model_class(args.classes, **model_kwargs(args))

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        if args.pretrained:
            print('Restoring from checkpoint...')
            restore_path = tf.train.latest_checkpoint(args.checkpoint)
            tf.train.Saver().restore(sess, restore_path)
        else:
            print('Training...')
            train(sess, model, train_set, test_set, args.checkpoint,
                  **train_kwargs(args))

        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)
        accuracy = evaluate(sess, model, test_set, **eval_kwargs)
        print('Test accuracy: ' + str(accuracy))
コード例 #8
0
def run_miniimagenet(model, args, train_set, val_set, test_set, checkpoint,
                     title):
    """
    Train (or restore) ``model`` and print its accuracy on every split.

    Args:
        model: the model to train and evaluate.
        args: parsed options; ``args.pretrained`` chooses restore vs. train,
            and the object is also passed to ``train_kwargs`` /
            ``evaluate_kwargs``.
        train_set, val_set, test_set: dataset splits; ``test_set`` is also
            handed to ``train`` alongside ``train_set``.
        checkpoint: location passed to ``train``, or the directory whose
            newest checkpoint is restored.
        title: label used in the phase banners.
    """
    print("\nTraining phase of " + title + "\n")

    with tf.Session() as sess:
        if args.pretrained:
            print('Restoring from checkpoint...')
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint))
        else:
            print('Training...')
            train(sess, model, train_set, test_set, checkpoint,
                  **train_kwargs(args))

        print("\nEvaluation phase of " + title + "\n")
        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)
        for label, split in (('Train accuracy: ', train_set),
                             ('Validation accuracy: ', val_set),
                             ('Test accuracy: ', test_set)):
            print(label + str(evaluate(sess, model, split, **eval_kwargs)))
コード例 #9
0
ファイル: train.py プロジェクト: GMonster0905/ARUBA
def _empty_metric_table(modes, metrics):
    """Return ``{mode: {metric: {}}}``; innermost dicts map trial index -> value."""
    return {mode: {metric: {} for metric in metrics} for mode in modes}


def _mean_over_trials(per_trial, trials):
    """Average ``per_trial[j]`` for ``j`` in ``0 .. trials - 1``.

    Only the current run's trials are averaged, so extra entries left in a
    results.pkl written by a run with more trials cannot skew the mean.
    """
    return sum(per_trial[j] for j in range(trials)) / trials


def _report_mean_accuracy(results, ev, prefix, trials):
    """Print mean train/test accuracy of evaluation type ``ev``.

    Skips (with a message) evaluation types absent from ``results``, which
    happens when results.pkl was created before that option was enabled.
    """
    if ev not in results:
        print('No', ev, 'results recorded; skipping.')
        return
    print(prefix + 'Train Acc:',
          _mean_over_trials(results[ev]['train']['acc'], trials))
    print(prefix + 'Test Acc:',
          _mean_over_trials(results[ev]['test']['acc'], trials))


def main():
    """
    Run ``args.trials`` train+evaluate trials and print mean accuracies.

    Per-trial results are stored in ``<checkpoint>/results.pkl`` (mirrored to
    ``results.json``) after every trial; trials already recorded there are
    skipped, so an interrupted run can be resumed. With ``--restore`` the
    first pending trial loads its ``final<j>`` checkpoint instead of
    training. Transductive and/or adaptive evaluations are added when
    ``args.transductive`` / ``args.adaptive`` are set.
    """
    parser = argument_parser()
    parser.add_argument('--dataset',
                        help='which dataset to use',
                        default='omniglot',
                        type=str)
    parser.add_argument('--trials',
                        help='number of seeds',
                        default=3,
                        type=int)
    parser.add_argument('--val-samples',
                        help='number of validation samples',
                        default=100,
                        type=int)
    parser.add_argument('--restore',
                        help='restore from final training checkpoint',
                        action='store_true')

    args = parser.parse_args()

    os.makedirs(args.checkpoint, exist_ok=True)
    with open(os.path.join(args.checkpoint, 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)

    random.seed(args.seed)
    if args.dataset == 'omniglot':
        train_set, test_set = split_dataset(
            read_dataset(DATA_DIR[args.dataset]))
        # trainval_set augments the whole training split; the last 200
        # entries are also held out as val_set (not used further below).
        trainval_set = list(augment_dataset(train_set))
        val_set = list(train_set[-200:])
        train_set = list(augment_dataset(train_set[:-200]))
        test_set = list(test_set)
    else:
        train_set, val_set, test_set = mini_dataset(DATA_DIR[args.dataset])
        trainval_set = train_set

    print('Training...')
    modes = ['train', 'test']
    metrics = ['acc', 'loss']
    results_path = os.path.join(args.checkpoint, 'results.pkl')
    try:
        # NOTE: results.pkl is written by this script; unpickling it is only
        # safe because the checkpoint directory is trusted.
        with open(results_path, 'rb') as f:
            results = pickle.load(f)
    except FileNotFoundError:
        results = {
            'regular': _empty_metric_table(modes, metrics),
            'trials': set()
        }
        if args.transductive:
            results['transductive'] = _empty_metric_table(modes, metrics)
        if args.adaptive:
            results['adaptive'] = _empty_metric_table(modes, metrics)
            if args.transductive:
                results['adaptive_transductive'] = _empty_metric_table(
                    modes, metrics)

    for j in range(args.trials):
        if j in results['trials']:
            print('skipping trial', j)
            continue
        tf.reset_default_graph()
        if args.dataset == 'omniglot':
            model = OmniglotModel(args.classes, **model_kwargs(args))
        else:
            model = MiniImageNetModel(args.classes, **model_kwargs(args))
        with tf.Session() as sess:
            checkpoint = os.path.join(args.checkpoint, 'final' + str(j))
            if args.restore:
                print('Trial', j, 'Restoring...')
                tf.train.Saver().restore(
                    sess, tf.train.latest_checkpoint(checkpoint))
                # Only the first pending trial is restored; later ones train.
                args.restore = False
            else:
                os.makedirs(checkpoint, exist_ok=True)
                print('Trial', j, 'Training...')
                train(sess, model, trainval_set, test_set, checkpoint,
                      **train_kwargs(args))
            print('Trial', j, 'Evaluating...')
            for ev in results:
                if ev == 'trials':
                    continue
                evkw = evaluate_kwargs(args)
                evkw['transductive'] = 'transductive' in ev
                evkw['adaptive'] = args.adaptive if 'adaptive' in ev else 0.0
                for mode, dset in zip(modes, [trainval_set, test_set]):
                    acc, loss = evaluate(sess, model, dset, **evkw)
                    results[ev][mode]['acc'][j] = acc
                    results[ev][mode]['loss'][j] = loss
            results['trials'].add(j)
        with open(results_path, 'wb') as f:
            pickle.dump(results, f)
        # JSON cannot encode the 'trials' set; dump everything else without
        # temporarily removing it from ``results``.
        with open(os.path.join(args.checkpoint, 'results.json'), 'w') as f:
            json.dump({k: v for k, v in results.items() if k != 'trials'},
                      f, indent=4)

    _report_mean_accuracy(results, 'regular', '', args.trials)
    if args.transductive:
        _report_mean_accuracy(results, 'transductive', 'Transductive ',
                              args.trials)
    if args.adaptive:
        _report_mean_accuracy(results, 'adaptive', 'Adaptive ', args.trials)
        if args.transductive:
            _report_mean_accuracy(results, 'adaptive_transductive',
                                  'Adaptive Transductive ', args.trials)