Пример #1
0
def main():
    """
    Train (or restore) a MiniImageNet model and report its test accuracy.

    Command-line arguments are parsed with ``argument_parser()``; Python's
    ``random`` module is seeded from ``args.seed``. When ``args.pretrained``
    is set, the latest checkpoint under ``args.checkpoint`` is restored
    instead of training from scratch.
    """
    args = argument_parser().parse_args()
    random.seed(args.seed)

    # The validation split is loaded but not used by this entry point.
    train_data, _val_data, test_data = read_dataset(DATA_DIR)
    network = MiniImageNetModel(args.classes, **model_kwargs(args))

    with tf.Session() as sess:
        if args.pretrained:
            print('Restoring from checkpoint...')
            latest = tf.train.latest_checkpoint(args.checkpoint)
            tf.train.Saver().restore(sess, latest)
        else:
            print('Training...')
            # NOTE(review): the test split is handed to train() as its second
            # dataset -- presumably for progress reporting; confirm in train().
            train(sess, network, train_data, test_data, args.checkpoint,
                  **train_kwargs(args))

        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)
        # Only the test split is evaluated here.
        test_accuracy = evaluate(sess, network, test_data, **eval_kwargs)
        print('Test accuracy: ' + str(test_accuracy))
Пример #2
0
def main():
    """
    Run the MiniImageNet experiment with and/or without regularization.

    ``args.org`` enables the plain model and ``args.reg`` enables the
    regularized one; each is run through ``run_miniimagenet`` with its own
    checkpoint directory. The default TensorFlow graph is reset between the
    two runs regardless of which flags are set.
    """
    args = argument_parser().parse_args()
    random.seed(args.seed)

    train_data, val_data, test_data = read_dataset(DATA_DIR)

    if args.org:
        plain_kwargs = model_kwargs(args)
        plain_model = MiniImageNetModel(args.classes, **plain_kwargs)
        run_miniimagenet(
            plain_model, args, train_data, val_data, test_data,
            args.checkpoint_org, "MODEL WITHOUT REGULARIZATION")

    # Start from a clean graph so the second model does not share state
    # with whatever the first run built.
    tf.reset_default_graph()

    if args.reg:
        regularized_kwargs = model_kwargs(args)
        regularized_model = RegularizedMiniImageNetModel(
            args.classes, **regularized_kwargs)
        run_miniimagenet(
            regularized_model, args, train_data, val_data, test_data,
            args.checkpoint_reg, "MODEL WITH REGULARIZATION")
def main():
    """
    Train (or restore) a progressive MiniImageNet model and log accuracies.

    Arguments come from the Neptune context via ``neptune_args``. After
    training or restoring, accuracy on the train, validation and test splits
    is evaluated in that order; each value is printed and sent to its own
    numeric Neptune channel.
    """
    ctx = neptune.Context()
    ctx.integrate_with_tensorflow()

    # Channels are created up front, before arguments are read, in
    # train / val / test order.
    channel_names = (
        'final_train_accuracy',
        'final_val_accuracy',
        'final_test_accuracy',
    )
    channels = [
        ctx.create_channel(name, neptune.ChannelType.NUMERIC)
        for name in channel_names
    ]

    args = neptune_args(ctx)
    print('args:', args)
    random.seed(args.seed)

    train_data, val_data, test_data = read_dataset(args.miniimagenet_src)
    network = ProgressiveMiniImageNetModel(args.classes, **model_kwargs(args))

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = args.allow_growth
    with tf.Session(config=session_config) as sess:
        if args.pretrained:
            print('Restoring from checkpoint...')
            latest = tf.train.latest_checkpoint(args.checkpoint)
            tf.train.Saver().restore(sess, latest)
        else:
            print('Training...')
            # NOTE(review): the test split is handed to train() as its second
            # dataset -- presumably for progress reporting; confirm in train().
            train(sess, network, train_data, test_data, args.checkpoint,
                  **train_kwargs(args))

        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)

        splits = (train_data, val_data, test_data)
        for name, channel, split in zip(channel_names, channels, splits):
            accuracy = evaluate(sess, network, split, **eval_kwargs)
            print(name + ':', accuracy)
            channel.send(accuracy)
Пример #4
0
def main():
    """
    Train (or restore) a MiniImageNet model on a chosen GPU and report its
    test accuracy.

    ``args.gpu`` is exported through ``CUDA_VISIBLE_DEVICES`` (with PCI bus
    ordering) before any TensorFlow session is created. ``args.metatransfer``
    selects the meta-transfer model class instead of the standard one. When
    ``args.pretrained`` is set, the latest checkpoint under
    ``args.checkpoint`` is restored instead of training.
    """
    args = argument_parser().parse_args()

    # Restrict which device(s) CUDA exposes to this process.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    random.seed(args.seed)

    # The validation split is loaded but not used by this entry point.
    train_data, _val_data, test_data = read_dataset(DATA_DIR)

    model_cls = (MiniImageNetMetaTransferModel if args.metatransfer
                 else MiniImageNetModel)
    network = model_cls(args.classes, **model_kwargs(args))

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        if args.pretrained:
            print('Restoring from checkpoint...')
            latest = tf.train.latest_checkpoint(args.checkpoint)
            tf.train.Saver().restore(sess, latest)
        else:
            print('Training...')
            # NOTE(review): the test split is handed to train() as its second
            # dataset -- presumably for progress reporting; confirm in train().
            train(sess, network, train_data, test_data, args.checkpoint,
                  **train_kwargs(args))

        print('Evaluating...')
        eval_kwargs = evaluate_kwargs(args)
        # Only the test split is evaluated here.
        test_accuracy = evaluate(sess, network, test_data, **eval_kwargs)
        print('Test accuracy: ' + str(test_accuracy))