예제 #1
0
def main():
    """Evaluate a pretrained DRIT model: report the test loss per batch.

    Loads options, the test data loader, and a resumed DRIT checkpoint in
    eval mode, then prints `model.test_model` loss for every test batch.
    """
    # parse options
    parser = TestOptions()
    opts = parser.parse()

    # data loader
    train_loader, input_data_par = get_loader(1)

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    model.resume(opts.resume, train=False)
    model.eval()

    # directory: makedirs also creates missing parents, and exist_ok avoids
    # the check-then-create race of the original exists()/mkdir() pair
    result_dir = os.path.join(opts.result_dir, opts.name)
    os.makedirs(result_dir, exist_ok=True)

    # test
    print('\n--- testing ---')
    for it, (images_a, images_b, labels) in enumerate(train_loader['test']):
        images_a = images_a.cuda(opts.gpu).detach()
        images_b = images_b.cuda(opts.gpu).detach()
        # no_grad: evaluation only, no autograd graph needed
        with torch.no_grad():
            loss = model.test_model(images_a, images_b)
            print('it:{}, loss:{}'.format(it, loss))
    return
예제 #2
0
def main():
    """Attribute-transfer test for DRIT.

    For each source-domain image, draws up to `opts.num` random images from
    the other domain, transfers their attributes onto the source image, and
    saves input + outputs under `<opts.result_dir>/<opts.name>/<index>`.
    """
    # parse options
    parser = TestOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    datasetA = dataset_single(opts, 'A', opts.input_dim_a)
    datasetB = dataset_single(opts, 'B', opts.input_dim_b)
    # Content images come from the source domain; attribute images are drawn
    # (shuffled) from the target domain.  Selecting the datasets first removes
    # the duplicated DataLoader construction of the original a2b/b2a branches.
    if opts.a2b:
        dataset_content, dataset_attr = datasetA, datasetB
    else:
        dataset_content, dataset_attr = datasetB, datasetA
    loader = torch.utils.data.DataLoader(dataset_content,
                                         batch_size=1,
                                         num_workers=opts.nThreads)
    loader_attr = torch.utils.data.DataLoader(dataset_attr,
                                              batch_size=1,
                                              num_workers=opts.nThreads,
                                              shuffle=True)

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    model.resume(opts.resume, train=False)
    model.eval()

    # directory: makedirs also creates missing parents, and exist_ok avoids
    # the check-then-create race of the original exists()/mkdir() pair
    result_dir = os.path.join(opts.result_dir, opts.name)
    os.makedirs(result_dir, exist_ok=True)

    # test
    print('\n--- testing ---')
    for idx1, img1 in enumerate(loader):
        print('{}/{}'.format(idx1, len(loader)))
        img1 = img1.cuda(opts.gpu)
        imgs = [img1]
        names = ['input']
        for idx2, img2 in enumerate(loader_attr):
            # cap the number of attribute samples per source image
            if idx2 == opts.num:
                break
            img2 = img2.cuda(opts.gpu)
            with torch.no_grad():
                # test_forward_transfer expects (content_A, content_B, a2b)
                if opts.a2b:
                    img = model.test_forward_transfer(img1, img2, a2b=True)
                else:
                    img = model.test_forward_transfer(img2, img1, a2b=False)
            imgs.append(img)
            names.append('output_{}'.format(idx2))
        save_imgs(imgs, names, os.path.join(result_dir, '{}'.format(idx1)))

    return
예제 #3
0
def main():
    """Translate test images with randomly sampled attributes.

    Loads the source-domain test set, restores a pretrained DRIT model, and
    writes `opts.num` random translations per input image under
    `<opts.result_dir>/<opts.name>/<a2b|b2a>/<index>`.
    """
    # parse options
    opts = TestOptions().parse()

    # data loader
    print('\n--- load dataset ---')
    if opts.a2b:
        domain, in_dim, subdir = 'A', opts.input_dim_a, "a2b"
    else:
        domain, in_dim, subdir = 'B', opts.input_dim_b, "b2a"
    src_dataset = dataset_single(opts, domain, in_dim)
    loader = torch.utils.data.DataLoader(src_dataset, batch_size=1, num_workers=opts.nThreads)

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    model.resume(opts.resume, train=False)
    model.eval()

    # output directory (per translation direction)
    result_dir = os.path.join(opts.result_dir, opts.name, subdir)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    # test
    print('\n--- testing ---')
    for src_idx, src_img in enumerate(loader):
        print('{}/{}'.format(src_idx, len(loader)))
        src_img = src_img.cuda()
        outputs = [src_img]
        labels = ['input']
        for sample_idx in range(opts.num):
            # each call samples a new random attribute code
            with torch.no_grad():
                translated = model.test_forward(src_img, a2b=opts.a2b)
            outputs.append(translated)
            labels.append('output_{}'.format(sample_idx))
        save_imgs(outputs, labels, os.path.join(result_dir, '{}'.format(src_idx)))

    return
예제 #4
0
def main():
    """Paired attribute transfer keyed by filename prefix.

    For each source image, scans the other domain for the image whose
    filename prefix matches, transfers its attribute, and saves the
    input / real / fake triplet into the result directory.
    """
    # parse options
    parser = TestOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    datasetA = dataset_single(opts, 'A', opts.input_dim_a)
    datasetB = dataset_single(opts, 'B', opts.input_dim_b)
    # Select content/attribute datasets once instead of duplicating the
    # DataLoader construction in both a2b/b2a branches.
    if opts.a2b:
        dataset_content, dataset_attr = datasetA, datasetB
    else:
        dataset_content, dataset_attr = datasetB, datasetA
    loader = torch.utils.data.DataLoader(dataset_content,
                                         batch_size=1,
                                         num_workers=opts.nThreads)
    loader_attr = torch.utils.data.DataLoader(dataset_attr,
                                              batch_size=1,
                                              num_workers=opts.nThreads,
                                              shuffle=True)

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    model.resume(opts.resume, train=False)
    model.eval()

    # directory: makedirs also creates missing parents, and exist_ok avoids
    # the check-then-create race of the original exists()/mkdir() pair
    result_dir = os.path.join(opts.result_dir, opts.name)
    os.makedirs(result_dir, exist_ok=True)

    # test
    print('\n--- testing ---')
    for idx1, (img1, img1_path) in enumerate(loader):
        print('{}/{}'.format(idx1, len(loader)))
        # the dataset yields the path inside a length-1 batch list
        img1_path = img1_path[0]
        img1_prefix = os.path.basename(img1_path).split('.')[0]
        img1 = img1.cuda()
        imgs = [img1]
        names = [f'{img1_prefix}_input']
        # linear scan for the attribute image with the same filename prefix
        for img2, img2_path in loader_attr:
            img2_path = img2_path[0]
            img2_prefix = os.path.basename(img2_path).split('.')[0]
            if img1_prefix == img2_prefix:
                img2 = img2.cuda()
                imgs.append(img2)
                names.append(f'{img2_prefix}_real')
                with torch.no_grad():
                    if opts.a2b:
                        img = model.test_forward_transfer(img1, img2, a2b=True)
                    else:
                        img = model.test_forward_transfer(img2,
                                                          img1,
                                                          a2b=False)
                imgs.append(img)
                names.append(f'{img2_prefix}_fake')
                break
        save_imgs(imgs, names, result_dir)

    return
예제 #5
0
파일: data_2.py 프로젝트: sunyue11/DRIT
# NOTE(review): this is a fragment of an evaluation script; it is truncated
# below — the subspace.forward_emb(...) call is cut off mid-argument-list,
# so the snippet is not runnable as shown.
opts.vocab_size = len(vocab)

# retrieval test-split loader (images + tokenized captions)
test_loader = data.get_test_loader('test', opts.data_name, vocab,
                                   opts.crop_size, opts.batch_size,
                                   opts.workers, opts)

# restore the pretrained VSE subspace from opts.resume2;
# val_start() presumably switches it to evaluation mode — TODO confirm
subspace = model_2.VSE(opts)
subspace.setgpu()
subspace.load_state_dict(torch.load(opts.resume2))
subspace.val_start()

# model
print('\n--- load model ---')
model = DRIT(opts)
model.setgpu(opts.gpu)
model.resume(opts.resume, train=False)
model.eval()

# accumulators for embeddings, allocated lazily inside the loop (elsewhere
# in this project they become numpy arrays indexed by `ids`)
a = None
b = None
c = None
d = None
for it, (images, captions, lengths, ids) in enumerate(test_loader):
    # evaluate at most opts.test_iter batches
    if it >= opts.test_iter:
        break
    images = images.cuda(opts.gpu).detach()
    captions = captions.cuda(opts.gpu).detach()

    img_emb, cap_emb = subspace.forward_emb(images,
                                            captions,
                                            lengths,
예제 #6
0
def main():
    """DRIT training entry point (optionally using a multi-modal dataset).

    Alternates content-discriminator updates with full D + E/G updates,
    periodically writing display images and model checkpoints via Saver.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')

    if opts.multi_modal:
        dataset = dataset_unpair_multi(opts)
    else:
        dataset = dataset_unpair(opts)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        # fresh run: initialize weights and start from epoch 0
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    for ep in range(ep0, opts.n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
            # skip the short leftover batch at the end of an epoch
            if images_a.size(0) != opts.batch_size or images_b.size(
                    0) != opts.batch_size:
                continue

            # input data
            images_a = images_a.cuda(opts.gpu).detach()
            images_b = images_b.cuda(opts.gpu).detach()

            # update model: most iterations train only the content
            # discriminator; every d_iter-th iteration updates D and E/G
            if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
                model.update_D_content(images_a, images_b)
                continue
            else:
                model.update_D(images_a, images_b)
                model.update_EG()

            # save to display file

            if not opts.no_display_img and not opts.multi_modal:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' %
                  (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1
            if total_it >= max_it:
                # saver.write_img(-1, model)
                # BUG FIX: write_model takes (ep, total_it, model) — see the
                # end-of-epoch call below; the original two-argument call
                # would raise TypeError when max_it is reached.
                saver.write_model(-1, total_it, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        if not opts.multi_modal:
            saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return
예제 #7
0
파일: train.py 프로젝트: sunyue11/DRIT
def main():
    """Jointly train a VSE image/text subspace and a DRIT translator.

    Phase 1 (epochs ep0..opts.pre_iter): pretrain the autoencoder on the
    subspace embeddings.  Phase 2 (epochs ep0..opts.n_ep): adversarial
    training (discriminator updates followed by encoder updates), with
    periodic i2t/t2i retrieval evaluation on the test split and
    best-model checkpointing.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    # NOTE(review): pickle.load on a file is only safe for trusted data, and
    # the file handle from open(...) is never closed here.
    vocab = pickle.load(
        open(os.path.join(opts.vocab_path, '%s_vocab.pkl' % opts.data_name),
             'rb'))
    vocab_size = len(vocab)
    opts.vocab_size = vocab_size
    torch.backends.cudnn.enabled = False
    # Load data loaders
    train_loader, val_loader = data.get_loaders(opts.data_name, vocab,
                                                opts.crop_size,
                                                opts.batch_size, opts.workers,
                                                opts)
    test_loader = data.get_test_loader('test', opts.data_name, vocab,
                                       opts.crop_size, opts.batch_size,
                                       opts.workers, opts)
    # model
    print('\n--- load subspace ---')
    subspace = model_2.VSE(opts)
    subspace.setgpu()
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:  # no previously saved checkpoint to resume from
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    # best retrieval score seen so far; used to decide best-model saving
    score = 0.0
    subspace.train_start()
    # --- phase 1: autoencoder pretraining on subspace embeddings ---
    for ep in range(ep0, opts.pre_iter):
        print('-----ep:{} --------'.format(ep))
        for it, (images, captions, lengths, ids) in enumerate(train_loader):
            # cap the number of batches per epoch
            if it >= opts.train_iter:
                break
            # input data
            images = images.cuda(opts.gpu).detach()
            captions = captions.cuda(opts.gpu).detach()

            img, cap = subspace.train_emb(images,
                                          captions,
                                          lengths,
                                          ids,
                                          pre=True)  # [b, 1024]

            subspace.pre_optimizer.zero_grad()
            # reshape flat embeddings into (b, C, 32, 32) feature maps so the
            # DRIT autoencoder can consume them
            img = img.view(images.size(0), -1, 32, 32)
            cap = cap.view(images.size(0), -1, 32, 32)

            model.pretrain_ae(img, cap)

            if opts.grad_clip > 0:
                clip_grad_norm(subspace.params, opts.grad_clip)

            subspace.pre_optimizer.step()

    # --- phase 2: adversarial training ---
    for ep in range(ep0, opts.n_ep):
        subspace.train_start()
        adjust_learning_rate(opts, subspace.optimizer, ep)
        for it, (images, captions, lengths, ids) in enumerate(train_loader):
            if it >= opts.train_iter:
                break
            # input data
            images = images.cuda(opts.gpu).detach()
            captions = captions.cuda(opts.gpu).detach()

            img, cap = subspace.train_emb(images, captions, lengths,
                                          ids)  # [b, 1024]

            img = img.view(images.size(0), -1, 32, 32)
            cap = cap.view(images.size(0), -1, 32, 32)

            subspace.optimizer.zero_grad()

            # unfreeze all four discriminators for the D updates
            for p in model.disA.parameters():
                p.requires_grad = True
            for p in model.disB.parameters():
                p.requires_grad = True
            for p in model.disA_attr.parameters():
                p.requires_grad = True
            for p in model.disB_attr.parameters():
                p.requires_grad = True

            for i in range(opts.niters_gan_d):  # 5
                model.update_D(img, cap)

            # freeze the discriminators again while updating the encoders
            for p in model.disA.parameters():
                p.requires_grad = False
            for p in model.disB.parameters():
                p.requires_grad = False
            for p in model.disA_attr.parameters():
                p.requires_grad = False
            for p in model.disB_attr.parameters():
                p.requires_grad = False

            for i in range(opts.niters_gan_enc):
                model.update_E(img, cap)  # uses the new content loss

            subspace.optimizer.step()

            # NOTE(review): uses %09f here while the sibling training scripts
            # use %08f — probably unintentional but harmless.
            print('total_it: %d (ep %d, it %d), lr %09f' %
                  (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        #saver.write_img(ep, model)
        if (ep + 1) % opts.n_ep == 0:
            # final epoch: save final model + subspace weights
            print('save model')
            filename = os.path.join(opts.result_dir, opts.name)
            model.save('%s/final_model.pth' % (filename), ep, total_it)
            torch.save(subspace.state_dict(),
                       '%s/final_subspace.pth' % (filename))
        elif (ep + 1) % 10 == 0:
            # periodic checkpoint every 10 epochs
            print('save model')
            filename = os.path.join(opts.result_dir, opts.name)
            model.save('%s/%s_model.pth' % (filename, str(ep + 1)), ep,
                       total_it)
            torch.save(subspace.state_dict(),
                       '%s/%s_subspace.pth' % (filename, str(ep + 1)))

        # --- periodic retrieval evaluation ---
        if (ep + 1) % opts.model_save_freq == 0:
            # embedding accumulators, allocated lazily on the first batch:
            # a/b hold subspace embeddings, c/d hold encoder outputs
            a = None
            b = None
            c = None
            d = None
            subspace.val_start()
            # pass 1: evaluate on at most opts.val_iter batches
            for it, (images, captions, lengths, ids) in enumerate(test_loader):
                if it >= opts.val_iter:
                    break
                images = images.cuda(opts.gpu).detach()
                captions = captions.cuda(opts.gpu).detach()

                img_emb, cap_emb = subspace.forward_emb(images,
                                                        captions,
                                                        lengths,
                                                        volatile=True)

                img = img_emb.view(images.size(0), -1, 32, 32)
                cap = cap_emb.view(images.size(0), -1, 32, 32)
                image1, text1 = model.test_model2(img, cap)
                img2 = image1.view(images.size(0), -1)
                cap2 = text1.view(images.size(0), -1)

                if a is None:
                    a = np.zeros(
                        (opts.val_iter * opts.batch_size, img_emb.size(1)))
                    b = np.zeros(
                        (opts.val_iter * opts.batch_size, cap_emb.size(1)))

                    c = np.zeros(
                        (opts.val_iter * opts.batch_size, img2.size(1)))
                    d = np.zeros(
                        (opts.val_iter * opts.batch_size, cap2.size(1)))

                # scatter this batch's embeddings into the dataset-indexed rows
                a[ids] = img_emb.data.cpu().numpy().copy()
                b[ids] = cap_emb.data.cpu().numpy().copy()

                c[ids] = img2.data.cpu().numpy().copy()
                d[ids] = cap2.data.cpu().numpy().copy()

            aa = torch.from_numpy(a)
            bb = torch.from_numpy(b)

            cc = torch.from_numpy(c)
            dd = torch.from_numpy(d)

            # i2t/t2i return (r@1, r@5, r@10, median rank, mean rank)
            (r1, r5, r10, medr, meanr) = i2t(aa, bb, measure=opts.measure)
            print('test640: subspace: Med:{}, r1:{}, r5:{}, r10:{}'.format(
                medr, r1, r5, r10))

            (r1i, r5i, r10i, medri, meanr) = t2i(aa, bb, measure=opts.measure)
            print('test640: subspace: Med:{}, r1:{}, r5:{}, r10:{}'.format(
                medri, r1i, r5i, r10i))

            (r2, r3, r4, m1, m2) = i2t(cc, dd, measure=opts.measure)
            print('test640: encoder: Med:{}, r1:{}, r5:{}, r10:{}'.format(
                m1, r2, r3, r4))

            (r2i, r3i, r4i, m1i, m2i) = t2i(cc, dd, measure=opts.measure)
            print('test640: encoder: Med:{}, r1:{}, r5:{}, r10:{}'.format(
                m1i, r2i, r3i, r4i))

            # combined encoder recall; keep the checkpoint that maximizes it
            curr = r2 + r3 + r4 + r2i + r3i + r4i

            if curr > score:
                score = curr
                print('save model')
                filename = os.path.join(opts.result_dir, opts.name)
                model.save('%s/best_model.pth' % (filename), ep, total_it)
                torch.save(subspace.state_dict(),
                           '%s/subspace.pth' % (filename))

            # pass 2: evaluate on the full test set (no val_iter cap)
            a = None
            b = None
            c = None
            d = None

            for it, (images, captions, lengths, ids) in enumerate(test_loader):

                images = images.cuda(opts.gpu).detach()
                captions = captions.cuda(opts.gpu).detach()

                img_emb, cap_emb = subspace.forward_emb(images,
                                                        captions,
                                                        lengths,
                                                        volatile=True)

                img = img_emb.view(images.size(0), -1, 32, 32)
                cap = cap_emb.view(images.size(0), -1, 32, 32)
                image1, text1 = model.test_model2(img, cap)
                img2 = image1.view(images.size(0), -1)
                cap2 = text1.view(images.size(0), -1)

                if a is None:
                    a = np.zeros((len(test_loader.dataset), img_emb.size(1)))
                    b = np.zeros((len(test_loader.dataset), cap_emb.size(1)))

                    c = np.zeros((len(test_loader.dataset), img2.size(1)))
                    d = np.zeros((len(test_loader.dataset), cap2.size(1)))

                a[ids] = img_emb.data.cpu().numpy().copy()
                b[ids] = cap_emb.data.cpu().numpy().copy()

                c[ids] = img2.data.cpu().numpy().copy()
                d[ids] = cap2.data.cpu().numpy().copy()

            aa = torch.from_numpy(a)
            bb = torch.from_numpy(b)

            cc = torch.from_numpy(c)
            dd = torch.from_numpy(d)

            (r1, r5, r10, medr, meanr) = i2t(aa, bb, measure=opts.measure)
            print('test5000: subspace: Med:{}, r1:{}, r5:{}, r10:{}'.format(
                medr, r1, r5, r10))

            (r1i, r5i, r10i, medri, meanr) = t2i(aa, bb, measure=opts.measure)
            print('test5000: subspace: Med:{}, r1:{}, r5:{}, r10:{}'.format(
                medri, r1i, r5i, r10i))

            (r2, r3, r4, m1, m2) = i2t(cc, dd, measure=opts.measure)
            print('test5000: encoder: Med:{}, r1:{}, r5:{}, r10:{}'.format(
                m1, r2, r3, r4))

            (r2i, r3i, r4i, m1i, m2i) = t2i(cc, dd, measure=opts.measure)
            print('test5000: encoder: Med:{}, r1:{}, r5:{}, r10:{}'.format(
                m1i, r2i, r3i, r4i))

    return
예제 #8
0
파일: train.py 프로젝트: jiaoyiping630/DRIT
def main():
    """DRIT training loop with an optional CPU-only debug mode."""

    # When True, skip moving the model and data to the GPU so the loop can
    # be stepped through on a CPU-only machine.
    debug_mode = False

    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=opts.batch_size, shuffle=True,
                                               num_workers=opts.nThreads)
    # NOTE: dataset_unpair first resizes images to 256x256, then crops a
    # random 216x216 patch (a center crop at test time).

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    if not debug_mode:
        model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    for ep in range(ep0, opts.n_ep):
        # images_a / images_b: (batch, 3, 216, 216)
        for it, (images_a, images_b) in enumerate(train_loader):
            # skip the short leftover batch at the end of an epoch
            if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
                continue

            # input data
            if not debug_mode:
                # detach to avoid tracking gradients through the inputs
                images_a = images_a.cuda(opts.gpu).detach()
                images_b = images_b.cuda(opts.gpu).detach()

            # update model: with default settings, most iterations train only
            # the content discriminator; every d_iter-th iteration updates D
            # and E/G
            if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
                model.update_D_content(images_a, images_b)
                continue
            else:
                model.update_D(images_a, images_b)
                model.update_EG()

            # save to display file
            if not opts.no_display_img:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            sys.stdout.flush()
            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                # BUG FIX: write_model takes (ep, total_it, model) — see the
                # end-of-epoch call below; the original two-argument call
                # would raise TypeError when max_it is reached.
                saver.write_model(-1, total_it, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return