示例#1
0
    # Load the run configuration and build episode samplers over the
    # 'train' and 'test' splits of args.dataset_dir.
    config = load_config(args.config)
    train_gen = EpisodeGenerator(args.dataset_dir, 'train', config)
    test_gen = EpisodeGenerator(args.dataset_dir, 'test', config)
    if args.train:
        # Total number of training episodes: (size of args.dataset_name * epochs)
        # divided by nway * qsize (presumably the query examples drawn per
        # episode), so one "epoch" is roughly one pass of query examples.
        max_iter = train_gen.dataset_size[args.dataset_name] * args.max_epoch \
                // (nway * qsize)
        # Convert the epoch-based logging / checkpoint intervals to iterations.
        # NOTE(review): when max_iter is small relative to args.max_epoch these
        # can round down to 0, and `i % show_step` below would then raise
        # ZeroDivisionError — confirm dataset sizes rule this out.
        show_step = args.show_epoch * max_iter // args.max_epoch
        save_step = args.save_epoch * max_iter // args.max_epoch
        # Running sums over the current logging window:
        # [accuracy, loss, unused (always 0 here), seconds per step].
        avger = np.zeros([4])
        for i in range(1, max_iter + 1):
            stt = time.time()
            # Epoch index implied by how many query examples have been consumed.
            cur_epoch = i * (
                nway * qsize) // train_gen.dataset_size[args.dataset_name]
            # Step schedule: base rate for the first 70% of iterations, then 10x lower.
            lr = args.lr if i < 0.7 * max_iter else args.lr * .1
            sx, sy, qx, qy = train_gen.get_episode(nway, kshot, qsize)
            # sy is not fed to the graph; presumably the model derives support
            # labels from the episode ordering — TODO confirm.
            fd = {\
                protonet.inputs['sx']: sx,
                protonet.inputs['qx']: qx,
                protonet.inputs['qy']: qy,
                lr_ph: lr}
            # Run train_op (presumably one optimizer update) and fetch accuracy
            # and loss for logging.
            p1, p2, _ = sess.run([acc, loss, train_op], fd)
            avger += [p1, p2, 0, time.time() - stt]

            # `i != 0` is always true here because i starts at 1.
            if i % show_step == 0 and i != 0:
                # Turn the window sums into per-step averages before printing.
                avger /= show_step
                print ('========= epoch : {:8d}/{} ========='\
                        .format(cur_epoch, args.max_epoch))
                print ('Training - ACC: {:.3f} '
                    '| LOSS: {:.3f}   '
                    '| lr : {:.3f}    '
示例#2
0
    # Load the run configuration and build the training episode sampler;
    # the test-split sampler is disabled in this variant.
    config = load_config(args.config)
    train_gen = EpisodeGenerator(args.dataset_dir, 'train', config)
    #test_gen = EpisodeGenerator(args.dataset_dir, 'test', config)
    if args.train:
        # Total number of training episodes: (size of args.dataset_name * epochs)
        # divided by nway * qsize (presumably the query examples drawn per
        # episode), so one "epoch" is roughly one pass of query examples.
        max_iter = train_gen.dataset_size[args.dataset_name] * args.max_epoch \
                // (nway * qsize)
        # Convert the epoch-based logging / checkpoint intervals to iterations.
        # NOTE(review): these can round down to 0 for small datasets, in which
        # case `i % show_step` below raises ZeroDivisionError — confirm.
        show_step = args.show_epoch * max_iter // args.max_epoch
        save_step = args.save_epoch * max_iter // args.max_epoch
        # Running sums over the current logging window:
        # [accuracy, loss, unused (always 0 here), seconds per step].
        avger = np.zeros([4])
        for i in range(1, max_iter + 1):
            stt = time.time()
            # Epoch index implied by how many query examples have been consumed.
            cur_epoch = i * (
                nway * qsize) // train_gen.dataset_size[args.dataset_name]
            # Step schedule: base rate for the first 70% of iterations, then 10x lower.
            lr = args.lr if i < 0.7 * max_iter else args.lr * .1
            # Episodes are restricted to the single dataset args.dataset_name.
            sx, sy, qx, qy = train_gen.get_episode(
                nway, kshot, qsize, dataset_name=args.dataset_name)
            # sy is not fed to the graph; presumably the model derives support
            # labels from the episode ordering — TODO confirm.
            fd = {\
                protonet.inputs['sx']: sx,
                protonet.inputs['qx']: qx,
                protonet.inputs['qy']: qy,
                lr_ph: lr}
            # Run train_op (presumably one optimizer update) and fetch accuracy
            # and loss for logging.
            p1, p2, _ = sess.run([acc, loss, train_op], fd)
            avger += [p1, p2, 0, time.time() - stt]

            # `i != 0` is always true here because i starts at 1.
            if i % show_step == 0 and i != 0:
                # Turn the window sums into per-step averages before printing.
                avger /= show_step
                print ('========= {:15s} epoch : {:8d}/{} ========='\
                        .format(args.dataset_name, cur_epoch, args.max_epoch))
                print ('Training - ACC: {:.3f} '
                    '| LOSS: {:.3f}   '
                    '| lr : {:.3f}    '
示例#3
0
    # Restore the ensemble checkpoint <model_dir>/<project_name>/<ens_name>.ckpt.
    load_loc = os.path.join(args.model_dir, 
            args.project_name, 
            args.ens_name + '.ckpt')
    saver.restore(sess, load_loc)
    
    # Evaluate on every dataset available in the 'test' split.
    ep_gen = EpisodeGenerator(args.dataset_dir, 'test', config)
    target_dataset = ep_gen.dataset_list
    #target_dataset = ['miniImagenet']
    # Result tables: one row per source model in TRAIN_DATASET plus two extra
    # rows (averaged-prediction ensemble, and the graph's own `acc`), one
    # column per target dataset.
    means = np.zeros([len(TRAIN_DATASET)+2, len(target_dataset)])
    stds  = np.zeros([len(TRAIN_DATASET)+2, len(target_dataset)])
    for tdi, dset in enumerate(target_dataset):
        print ('==========TARGET : {}=========='.format(dset)) 
        # Per-episode accuracies, one list per row of `means` / `stds`.
        temp_results = [[] for _ in range(len(TRAIN_DATASET)+2)]
        for i in range(args.max_iter):
            sx, sy, qx, qy = ep_gen.get_episode(nway, kshot, qsize,
                    dataset_name=dset)
            fd = {sx_ph: sx, qx_ph: qx, qy_ph: qy}
            ps, p_acc, p_W = sess.run([preds, acc, pemb], fd)
            # Presumably (num_models, num_queries), e.g. (5, 150).
            # NOTE(review): `prediction` and `p_W` are not used in this excerpt.
            prediction = np.argmax(ps, axis=2) # (5, 150)

            # Accuracy of each model separately: ps is iterated over its first
            # axis (presumably one prediction array per model) and compared to
            # argmax(qy) (presumably one-hot query labels).
            for pn, p in enumerate(ps): 
                temp_p = np.argmax(p, 1)
                temp_results[pn].append(np.mean(temp_p == np.argmax(qy, 1)))
            # Row pn+1: accuracy of predictions averaged over ps's first axis
            # (the ensemble); row pn+2: the `acc` value fetched from the graph.
            # NOTE(review): this reuses `pn` leaked from the loop above, so it
            # raises NameError if ps is empty, and assumes len(ps) equals
            # len(TRAIN_DATASET) to line up with `means` — confirm.
            temp_results[pn+1].append(np.mean(np.argmax(np.mean(ps,0),1)\
                    == np.argmax(qy,1)))
            temp_results[pn+2].append(p_acc)

        # Report and record the mean accuracy of each individual source model.
        for i in range(len(TRAIN_DATASET)):
            print ('model trained on {:10s} - Acc {:.3f}'.format(\
                    TRAIN_DATASET[i], np.mean(temp_results[i])))
            means[i, tdi] = np.mean(temp_results[i])
示例#4
0
    # Print every trainable variable with its index, for inspecting the graph.
    vlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    for i, v in enumerate(vlist):
        print(i, v)

    if args.train:
        # Running sums for the logging window (4 slots).
        avger = np.zeros([4])
        for i in range(1, args.max_iter + 1):
            stt = time.time()
            # Outer-loop (meta) learning rate: 10x lower after 70% of iterations.
            lr = args.meta_lr if i < 0.7 * args.max_iter else args.meta_lr * .1

            # Collect one meta-batch of args.meta_batch_size independent episodes.
            sx = []
            sy = []
            qx = []
            qy = []
            for _ in range(args.meta_batch_size):
                # NOTE(review): training episodes are drawn with `test_kshot`
                # rather than a train-specific shot count — confirm intended.
                sxt, syt, qxt, qyt = ep_train.get_episode(
                    nway, test_kshot, qsize)
                sx.append(sxt)
                sy.append(syt)
                qx.append(qxt)
                qy.append(qyt)

            # Fed as lists of per-episode arrays, presumably shaped
            #   sx : (metabatchsize, nway * test_kshot, img_size)  support inputs
            #   qx : (metabatchsize, nway * qsize, img_size)       query inputs
            # — inferred from the get_episode argument order; TODO confirm.
            # lr_alpha (inner-loop step) is fixed at args.inner_lr, while
            # lr_beta (outer-loop step) follows the schedule above.
            ip = mamlnet.inputs
            feed_dict = {
                ip['sx']: sx,
                ip['sy']: sy,
                ip['qx']: qx,
                ip['qy']: qy,
                ip['lr_alpha']: args.inner_lr,
                ip['lr_beta']: lr
示例#5
0
    # Step-decay schedule: 10x lower learning rate after 70% of max_iter.
    lr = args.lr if i < 0.7 * max_iter else args.lr * .1
    # Episodes are pulled from a queue, presumably filled by a background producer.
    sx, sy, qx, qy = ep_queue.get()
    # isTr=True here and False during evaluation below — presumably a
    # training-mode flag (e.g. for batch norm / dropout).
    fd = {sx_ph: sx, qx_ph: qx, qy_ph: qy, lr_ph: lr, isTr: True}
    p1, p2, _ = sess.run([acc, loss, train_op], fd)
    # Window sums: [accuracy, loss, unused (always 0 here), seconds per step].
    avger += [p1, p2, 0, time.time() - stt]

    if i % show_step == 0 and i != 0:
        # Convert sums to averages; avger[3]*show_step is the window's total time.
        avger /= show_step
        print('epoch : {:8d}/{}  ACC: {:.3f}   |   LOSS: {:.3f}   |  lr : {:.3f}   | in {:.2f} secs'\
                .format(cur_epoch, args.max_epoch, avger[0],
                    avger[1], lr, avger[3]*show_step))
        # Start a fresh logging window.
        avger[:] = 0

    # Periodically checkpoint, then evaluate on every dataset in TEST_DATASETS.
    if i % save_step == 0 and i != 0:
        out_loc = os.path.join(
            args.model_dir,  # models/
            args.model_name,  # bline/
            'all_datasets.ckpt')  # -> models/bline/all_datasets.ckpt
        model.save(saver, sess, out_loc)

        # Report mean and std of accuracy over TEST_NUM_EPISODES episodes per
        # test dataset.
        for test_set_name in TEST_DATASETS:
            test_accs = []
            for j in range(TEST_NUM_EPISODES):
                sx_t, sy_t, qx_t, qy_t = ep_gen_test.get_episode(
                    nway, kshot, qsize, test_set_name)
                # Evaluation feed: no learning rate, isTr=False.
                fd = {sx_ph: sx_t, qx_ph: qx_t, qy_ph: qy_t, isTr: False}
                acc_ = sess.run(acc, fd)
                test_accs.append(acc_)
            # NOTE(review): "datset" in the message below is a typo for "dataset".
            print('datset name: {}, test_acc: {:.3f} {:.5f}'.format(
                test_set_name, np.mean(test_accs), np.std(test_accs)))
示例#6
0
            # Restore sub-model `tn` from its own checkpoint.
            loaders[tn].restore(sess, saved_loc)
            print('model_{} restored from {}'.format(tn, saved_loc))
    else:
        # Otherwise restore the whole ensemble from
        # <model_dir>/<project_name>/<ens_name>.ckpt.
        out_loc = os.path.join(args.model_dir, args.project_name,
                               args.ens_name + '.ckpt')
        saver.restore(sess, out_loc)

    # NOTE(review): train_op below is run on episodes from the 'val' split;
    # ep_test is not used in this excerpt — confirm both are intended.
    ep_gen = EpisodeGenerator(args.dataset_dir, 'val', config)
    ep_test = EpisodeGenerator(args.dataset_dir, 'test', config)
    # Window sums: [loss, accuracy, episode-fetch seconds, total step seconds]
    # (note: loss comes before accuracy here).
    avger = np.zeros([4])
    # Global numpy print format: 3 decimal places.
    np.set_printoptions(precision=3, suppress=False)
    for i in range(1, args.max_iter + 1):
        stt = time.time()
        # Step-decay schedule: 10x lower learning rate after 70% of iterations.
        lr = args.lr if i < 0.7 * args.max_iter else args.lr * .1
        sx, sy, qx, qy = ep_gen.get_episode(nway,
                                            kshot,
                                            qsize,
                                            printname=False)
        # Time spent just producing the episode.
        bedt = time.time() - stt
        fd = {sx_ph: sx, qx_ph: qx, qy_ph: qy, lr_ph: lr}

        # NOTE(review): W_val (fetched from `fe`) is not used in this excerpt.
        loss_val, acc_val, W_val, _ = sess.run([loss, acc, fe, train_op], fd)
        avger += [loss_val, acc_val, bedt, time.time() - stt]

        # `i != 0` is always true here because i starts at 1.
        if i % args.show_step == 0 and i != 0:
            # Averages for loss/acc; timings are scaled back up to window totals.
            avger /= args.show_step
            print ('step : {:8d}/{}  loss: {:.3f}   |   acc: {:.3f}  | batchtime: {:.2f}  | in {:.2f} secs'\
                    .format(i, args.max_iter, avger[0],
                        avger[1], avger[2]*args.show_step, avger[3]*args.show_step))
            # Start a fresh logging window.
            avger[:] = 0

#            print ('+'*30)