예제 #1
0
def test(args):
    """Run ``trans`` with the real->cartoon and cartoon models on every test cluster.

    Builds two (E1, E2, Decoder) triplets, optionally restores each from its
    checkpoint, switches them to eval mode, makes sure the output directory
    exists, and calls ``trans`` once per (domain A, domain B) cluster pair
    returned by ``my_get_test_imgs``.

    Args:
        args: parsed options. Reads ``sep``, ``resize``, ``load_rc``,
            ``load_c`` and ``out`` directly; the whole object is also
            forwarded to ``my_get_test_imgs`` and ``trans``.
    """
    # Both models are built with the same size argument.
    size = int(args.resize / 64)

    def _load_eval_triplet(checkpoint):
        """Build E1/E2/Decoder, move them to the GPU when CUDA is available,
        load ``checkpoint`` unless it is '' and return them in eval mode."""
        e1 = E1(args.sep, size)
        e2 = E2(args.sep, size)
        decoder = Decoder(size)
        if torch.cuda.is_available():
            e1 = e1.cuda()
            e2 = e2.cuda()
            decoder = decoder.cuda()
        if checkpoint != '':
            load_model_for_eval(checkpoint, e1, e2, decoder)
        return e1.eval(), e2.eval(), decoder.eval()

    # ---------- load model_real_cartoon ---------- #
    rc_e1, rc_e2, rc_decoder = _load_eval_triplet(args.load_rc)

    # ---------- load model_cartoon ---------- #
    c_e1, c_e2, c_decoder = _load_eval_triplet(args.load_c)

    # -------------- running -------------- #
    if args.out != "":
        # makedirs also creates missing parent directories and tolerates an
        # already existing directory (unlike exists()+mkdir).
        os.makedirs(args.out, exist_ok=True)

    test_domA_cluster, test_domB_cluster = my_get_test_imgs(args)
    # zip stops at the shorter of the two cluster sequences.
    for idx, (test_domA, test_domB) in enumerate(
            zip(test_domA_cluster, test_domB_cluster)):
        trans(args, idx, test_domA, test_domB, rc_e1, rc_e2, rc_decoder,
              c_e1, c_e2, c_decoder)
def eval(args):
    """Save chosen sample outputs of a model restored from ``args.load``.

    Builds the E_common, E_separate_A, E_separate_B and Decoder networks,
    moves them to the GPU when CUDA is available, optionally restores them
    from ``<args.load>/checkpoint``, switches them to eval mode, makes sure
    the output directory exists and calls ``save_chosen_imgs``.

    Args:
        args: parsed options; reads ``sep``, ``resize``, ``load`` and ``out``
            and forwards the whole object to ``save_chosen_imgs``.

    NOTE: this name shadows the built-in ``eval``; it is kept for callers.
    """
    size = int(args.resize / 64)
    e_common = E_common(args.sep, size)
    e_separate_A = E_separate_A(args.sep, size)
    e_separate_B = E_separate_B(args.sep, size)
    decoder = Decoder(size)

    if torch.cuda.is_available():
        e_common = e_common.cuda()
        e_separate_A = e_separate_A.cuda()
        e_separate_B = e_separate_B.cuda()
        decoder = decoder.cuda()

    # Value returned by load_model_for_eval and forwarded to save_chosen_imgs
    # (presumably the checkpoint's training iteration -- confirm). Defaulting
    # to 0 avoids an UnboundLocalError when args.load is empty.
    _iter = 0
    if args.load != '':
        save_file = os.path.join(args.load, 'checkpoint')
        _iter = load_model_for_eval(save_file, e_common, e_separate_A,
                                    e_separate_B, decoder)

    e_common = e_common.eval()
    e_separate_A = e_separate_A.eval()
    e_separate_B = e_separate_B.eval()
    decoder = decoder.eval()

    if args.out != "":
        # Creates missing parents; no error if the directory already exists.
        os.makedirs(args.out, exist_ok=True)

    save_chosen_imgs(args, e_common, e_separate_A, e_separate_B, decoder,
                     _iter, [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], False)
예제 #3
0
def test(args):
    """Run ``trans`` once with the real->cartoon and cartoon models.

    Builds two (E1, E2, Decoder) triplets, optionally restores each from its
    checkpoint, switches them to eval mode, makes sure the output directory
    exists and hands all six networks to ``trans``.

    Args:
        args: parsed options. Reads ``sep``, ``resize``, ``load_rc``,
            ``load_c`` and ``out`` directly; the whole object is also
            forwarded to ``trans``.
    """
    # Both models are built with the same size argument.
    size = int(args.resize / 64)

    def _load_eval_triplet(checkpoint):
        """Build E1/E2/Decoder, move them to the GPU when CUDA is available,
        load ``checkpoint`` unless it is '' and return them in eval mode."""
        e1 = E1(args.sep, size)
        e2 = E2(args.sep, size)
        decoder = Decoder(size)
        if torch.cuda.is_available():
            e1 = e1.cuda()
            e2 = e2.cuda()
            decoder = decoder.cuda()
        if checkpoint != '':
            load_model_for_eval(checkpoint, e1, e2, decoder)
        return e1.eval(), e2.eval(), decoder.eval()

    # ---------- load model_real_cartoon ---------- #
    rc_e1, rc_e2, rc_decoder = _load_eval_triplet(args.load_rc)

    # ---------- load model_cartoon ---------- #
    c_e1, c_e2, c_decoder = _load_eval_triplet(args.load_c)

    # -------------- running -------------- #
    if args.out != "":
        # makedirs also creates missing parent directories and tolerates an
        # already existing directory (unlike exists()+mkdir).
        os.makedirs(args.out, exist_ok=True)

    trans(args, rc_e1, rc_e2, rc_decoder, c_e1, c_e2, c_decoder)
예제 #4
0
def get_eval_model(load, sep, resize):
    """Build an (E1, E2, Decoder) triplet restored from ``load`` for inference.

    Args:
        load: checkpoint path passed unchanged to ``load_model_for_eval``.
        sep: forwarded to the E1 and E2 constructors.
        resize: the networks are constructed with ``int(resize / 64)``.

    Returns:
        tuple: ``(e1, e2, decoder)``, moved to the GPU when CUDA is available
        and switched to eval mode.
    """
    size = int(resize / 64)
    e1 = E1(sep, size)
    e2 = E2(sep, size)
    decoder = Decoder(size)

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()

    # The value load_model_for_eval returns is not needed here.
    load_model_for_eval(load, e1, e2, decoder)

    return e1.eval(), e2.eval(), decoder.eval()
예제 #5
0
def eval(args):
    """Save sample outputs of an E1/E2/Decoder model restored from ``args.load``.

    Builds the networks, moves them to the GPU when CUDA is available,
    optionally restores them from ``<args.load>/checkpoint``, switches them
    to eval mode, makes sure the output directory exists and calls
    ``save_imgs``.

    Args:
        args: parsed options; reads ``sep``, ``resize``, ``load`` and ``out``
            and forwards the whole object to ``save_imgs``.

    NOTE: this name shadows the built-in ``eval``; it is kept for callers.
    """
    size = int(args.resize / 64)
    e1 = E1(args.sep, size)
    e2 = E2(args.sep, size)
    decoder = Decoder(size)

    if torch.cuda.is_available():
        e1 = e1.cuda()
        e2 = e2.cuda()
        decoder = decoder.cuda()

    # Value returned by load_model_for_eval and forwarded to save_imgs
    # (presumably the checkpoint's training iteration -- confirm). Defaulting
    # to 0 avoids an UnboundLocalError when args.load is empty.
    _iter = 0
    if args.load != '':
        save_file = os.path.join(args.load, 'checkpoint')
        _iter = load_model_for_eval(save_file, e1, e2, decoder)

    e1 = e1.eval()
    e2 = e2.eval()
    decoder = decoder.eval()

    if args.out != "":
        # Creates missing parents; no error if the directory already exists.
        os.makedirs(args.out, exist_ok=True)

    save_imgs(args, e1, e2, decoder, _iter)