Example 1
# imports assumed from the DRIT repo layout (options.py, dataset.py, model.py,
# saver.py); dataset_unpair_multi is assumed to live alongside dataset_unpair
import torch

from options import TrainOptions
from dataset import dataset_unpair, dataset_unpair_multi
from model import DRIT
from saver import Saver

def main():
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')

    if opts.multi_modal:
        dataset = dataset_unpair_multi(opts)
    else:
        dataset = dataset_unpair(opts)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
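    # ep0 is -1 on a fresh run, so the increment below starts training at epoch 0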
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    for ep in range(ep0, opts.n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
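            # skip short leftover batches so every update sees a full batch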
            if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
                continue

            # input data
            images_a = images_a.cuda(opts.gpu).detach()
            images_b = images_b.cuda(opts.gpu).detach()

            # update model
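            # with the default d_iter = 3, two of every three iterations train only
            # the content discriminator; the third runs the full D and EG updates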
            if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
                model.update_D_content(images_a, images_b)
                continue
            else:
                model.update_D(images_a, images_b)
                model.update_EG()

            # save to display file
            if not opts.no_display_img and not opts.multi_modal:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' %
                  (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1
            if total_it >= max_it:
                # saver.write_img(-1, model)
                saver.write_model(-1, total_it, model)  # match write_model's (ep, total_it, model) signature used below
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        if not opts.multi_modal:
            saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return
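A note on the update schedule: the `(it + 1) % opts.d_iter != 0` test above decides, per iteration, whether to train only the content discriminator or to run the full D and EG updates. Below is a minimal standalone trace, assuming the DRIT default d_iter = 3 and a hypothetical 10-batch epoch (num_batches is illustrative, not from the source):

# Standalone trace of the update schedule used in the loop above.
# Assumes d_iter = 3 (the DRIT default) and a hypothetical 10-batch epoch.
d_iter = 3
num_batches = 10

for it in range(num_batches):
    if (it + 1) % d_iter != 0 and it < num_batches - 2:
        print('it %d: update_D_content only' % it)
    else:
        print('it %d: update_D + update_EG' % it)

# Output:
# it 0: update_D_content only
# it 1: update_D_content only
# it 2: update_D + update_EG
# ...the pattern repeats...
# it 8: update_D + update_EG
# it 9: update_D + update_EG  (the "it < num_batches - 2" guard forces full
#       updates at the end of the epoch, so it never ends on a content-only step)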
Example 2
# imports assume the same DRIT repo layout as Example 1
import sys

import torch

from options import TrainOptions
from dataset import dataset_unpair
from model import DRIT
from saver import Saver

def main():

    debug_mode=False

    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=opts.batch_size, shuffle=True,
                                               num_workers=opts.nThreads)
    '''
        Inspecting dataset_unpair shows that images are first resized to
        256x256 and then a random 216x216 patch is cropped (a center crop
        is used at test time).
    '''

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    if not debug_mode:
        model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    for ep in range(ep0, opts.n_ep):
        '''
            images_a, images_b: [2, 3, 216, 216] (batch, channels, height, width)
        '''
        for it, (images_a, images_b) in enumerate(train_loader):
            # if the loader hands back a short leftover batch (one or two samples), skip it and resample
            if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
                continue

            # input data
            if not debug_mode:
                images_a = images_a.cuda(opts.gpu).detach()  # detach here is probably to avoid tracking gradients that are not needed, saving GPU memory
                images_b = images_b.cuda(opts.gpu).detach()

            # update model: with the default d_iter = 3, two of every three iterations
            # update only the content discriminator; the third updates D and EG
            if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
                model.update_D_content(images_a, images_b)
                continue
            else:
                model.update_D(images_a, images_b)
                model.update_EG()

            # save to display file
            if not opts.no_display_img:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            sys.stdout.flush()
            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                saver.write_model(-1, total_it, model)  # write_model takes (ep, total_it, model), matching the call below
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return
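The translated docstring in Example 2 describes dataset_unpair's preprocessing: resize to 256x256, then crop a 216x216 patch, randomly during training and from the center at test time. A minimal torchvision sketch of such a pipeline, assuming those sizes and a conventional [-1, 1] normalization (the actual transform list in dataset.py may differ):

# Hedged sketch of the preprocessing described in the docstring above.
# Sizes come from that comment; normalization to [-1, 1] is an assumption.
from torchvision import transforms

def make_transform(train=True):
    crop = transforms.RandomCrop(216) if train else transforms.CenterCrop(216)
    return transforms.Compose([
        transforms.Resize((256, 256)),   # resize to 256x256 first
        crop,                            # then take a 216x216 patch
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ])

train_tf = make_transform(train=True)   # random crop for training
test_tf = make_transform(train=False)   # center crop for evaluation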