Ejemplo n.º 1
0
def test_operate(way, ns, nq, path, eval_stats, ob_domain=False, num_domain=30):
    """Load a trained DASMNLearner from ``path`` and evaluate it on target tasks.

    Args:
        way: Number of classes per episode (passed as ``n_way``).
        ns: Support examples per class (``n_support``); replaced by
            ``num_domain`` when ``ob_domain`` is true.
        nq: Query examples per class (``n_query``); replaced by
            ``num_domain`` when ``ob_domain`` is true.
        path: Checkpoint passed to ``model.load``.
        eval_stats: Which evaluation(s) to run:
            ``'yes'``  -> ``model.test(..., model_eval=True)``;
            ``'no'``   -> ``model.test(..., model_eval=False)``;
            ``'both'`` -> both, reloading the checkpoint in between.
        ob_domain: If true, also hand the source tasks to ``model.test`` so
            source and target results are observed together. In that case
            ``ns``/``nq`` are set to ``num_domain``,
            ``running_params['train_split']`` is set to ``num_domain * 4``,
            and 10 (instead of 100) test episodes are used.
        num_domain: Value used for ``ns``/``nq`` when ``ob_domain`` is true.

    Side effects:
        Mutates the global ``running_params`` mapping (``'test_episodes'``
        always, ``'train_split'`` when ``ob_domain`` is true) and prints the
        task shapes.

    Raises:
        ValueError: If ``eval_stats`` is not ``'yes'``, ``'no'`` or ``'both'``.
    """
    valid_modes = ('yes', 'no', 'both')
    # Fail fast: previously an unknown value silently skipped all evaluation
    # after seeding, mutating running_params and loading the data.
    if eval_stats not in valid_modes:
        raise ValueError('eval_stats must be one of {}, got {!r}'.format(
            valid_modes, eval_stats))

    set_seed(0)
    if ob_domain:
        ns = nq = num_domain
        running_params['train_split'] = num_domain * 4

    model = DASMNLearner(n_way=way, n_support=ns, n_query=nq)

    # Alternative dataset settings (uncomment one to switch experiments):
    #
    # CW: NC, IF, OF, RoF
    # src_tasks, _ = generator.CW_10way(way=way, order=Load[0], examples=200, split=running_params['train_split'],
    #                                   normalize=True, data_len=DIM, label=False)
    # _, test_tasks = generator.CW_10way(way=way, order=Load[1], examples=200, split=0, normalize=True,
    #                                    data_len=DIM, label=False)
    #
    # CW2SQ: NC, IF, OF
    # src_tasks, _ = generator.CW_cross(way=way, examples=100, split=running_params['train_split'],
    #                                   normalize=True, data_len=DIM, label=False, tgt_set='sq')
    # _, test_tasks = generator.SQ_37way(examples=100, split=0, way=way, normalize=True,
    #                                    label=False, data_len=DIM)
    # _, test_tasks = generator.EB_3_13way(examples=running_params['test_samples'],
    #                                      split=running_params['test_split'],
    #                                      way=way, order=3, normalize=True, label=False, data_len=DIM)
    #
    # CW2SA: NC, OF, RoF(Ball)
    # src_tasks, _ = generator.CW_cross(way=way, examples=100, split=running_params['train_split'], data_len=DIM,
    #                                   normalize=True, label=False, tgt_set='sa')
    # _, test_tasks = generator.SA_37way(examples=200, split=0, way=way, normalize=True,
    #                                    label=False, data_len=DIM)

    # Active setting: both source and test tasks come from SA_37way.
    # NOTE(review): split is hard-coded to 30 here and does not follow the
    # running_params['train_split'] value set above for ob_domain -- confirm.
    src_tasks, test_tasks = generator.SA_37way(examples=200, split=30, way=way, normalize=True,
                                               label=False, data_len=DIM)

    # Source tasks are only evaluated when observing both domains together.
    src_tasks = src_tasks if ob_domain else None
    running_params['test_episodes'] = 10 if ob_domain else 100
    print('test_task shape: ', test_tasks.shape)
    if ob_domain:
        print('src_task shape: ', src_tasks.shape)

    model.load(path)
    if eval_stats in ('yes', 'both'):
        model.test(test_tasks, src_tasks, model_eval=True)
    if eval_stats == 'both':
        print('\n================Reloading the file====================')
        # Author's note: testing with model_eval=False (Model.train()) can
        # change the trained weights, so start that run from a fresh copy
        # of the checkpoint.
        model.load(path)
    if eval_stats in ('no', 'both'):
        model.test(test_tasks, src_tasks, model_eval=False)
Ejemplo n.º 2
0
def train_operate(way, ns, nq, save_path, final_test=True, load_path=None):
    """Train a DASMNLearner on the CW -> SQ cross-domain setting.

    Args:
        way: Number of classes per episode (passed as ``n_way``).
        ns: Support examples per class (``n_support``).
        nq: Query examples per class (``n_query``).
        save_path: Forwarded to ``joint_training_2op``.
        final_test: If true, evaluate on the target tasks after training,
            first with ``model_eval=True`` and then with ``model_eval=False``.
        load_path: Optional checkpoint. When given, its weights are loaded
            before training, i.e. the model is fine-tuned.
    """
    set_seed(0)
    learner = DASMNLearner(n_way=way, n_support=ns, n_query=nq)
    # A non-None load path means: fine-tune an existing model.
    if load_path is not None:
        learner.load(load_path)

    # Active setting -- CW2SQ (NC, IF, OF): source tasks from CW, target from SQ.
    src, _ = generator.CW_cross(
        way=way, examples=100, split=running_params['train_split'],
        data_len=DIM, normalize=True, label=False, tgt_set='sq',
    )
    _, tar = generator.SQ_37way(
        examples=100, split=0, way=way,
        label=False, data_len=DIM, normalize=True,
    )

    # Alternative dataset settings (swap in for the two calls above):
    #
    # CW: NC, IF, OF, RoF
    #   src, _ = generator.CW_10way(way=way, order=Load[0], examples=200, split=running_params['train_split'],
    #                               normalize=True, data_len=DIM, label=False)
    #   _, tar = generator.CW_10way(way=way, order=Load[1], examples=200, split=0, normalize=True,
    #                               data_len=DIM, label=False)
    #
    # EB:
    #   src, tar = generator.EB_3_13way(examples=running_params['train_samples'],
    #                                   split=running_params['train_split'],
    #                                   way=way, order=3, normalize=True, label=False)
    #
    # CW2SA: NC, OF, RoF(Ball)
    #   src, _ = generator.CW_cross(way=way, examples=100, split=running_params['train_split'],
    #                               normalize=True, data_len=DIM, label=False, tgt_set='sa')
    #   _, tar = generator.SA_37way(examples=200, split=0, way=way, data_len=DIM,
    #                               normalize=True, label=False)
    #
    # Self-testing on SA:
    #   src, tar = generator.SA_37way(examples=200, split=running_params['train_split'],
    #                                 way=way, normalize=True, label=False, overlap=True)

    # Alternative training schemes:
    #   print('Train proto_1optimizer\n')   # recommended
    #   learner.train_proto_1op(src, tar)    # GRL turned on
    #   print('Train proto_2loss')
    #   learner.train_proto_2loss(src, tar)  # GRL must be turned OFF

    # Active scheme (recommended): joint training with 2 optimizers, GRL on.
    print('Train joint_training')
    learner.joint_training_2op(src, tar, save_path)

    if not final_test:
        return
    print('We test the model!')
    for eval_flag in (True, False):
        learner.test(tar_tasks=tar, src_tasks=None, model_eval=eval_flag)