# ---- Beispiel #1 (example 1) ----
def main():
    """Train a cycle-consistent WGAN-GP with two generator/discriminator pairs.

    net_g_ct / net_d_ct map DR projections to (apparently Fourier-domain) CT
    volumes; net_g_dr / net_d_dr map CT volumes back to DR projections.
    Discriminators (critics) are updated every iteration with the WGAN-GP
    loss; generators are updated every CRITIC_ITERS iterations using only the
    two cycle-consistency L1 losses (the adversarial generator terms are
    computed but excluded from `loss_gan` -- see the commented line below).

    NOTE(review): relies on module-level names defined elsewhere in the file
    (plt, torch, nn, autograd, itertools, os, networks, USE_CUDA, load_net,
    epoch_start, MAX_EPOCH, CRITIC_ITERS, freeze_params, unfreeze_params,
    calc_gradient_penalty, toimage, save_image, ceshi, load_network,
    create_dataset, TrainOptions).
    """
    plt.ion()  # enable matplotlib interactive mode for continuous plotting
    opt = TrainOptions().parse()
    # Device used for computation: CPU or GPU
    device = torch.device("cuda" if USE_CUDA else "cpu")

    # Build the discriminator and generator networks
    #net_d = NLayerDiscriminator(opt.output_nc, opt.ndf, n_layers=3)#batchnorm
    #net_d = Discriminator(opt.output_nc)
    net_d_ct =networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)
    net_d_dr =networks.define_D(opt.input_nc, opt.ndf, 'ProjNet',
                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)
    net_g_dr=networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)

    #net_g = CTGenerator(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=6)
    # CT generator: hard-coded 1 input channel / 65 output channels for the
    # 'CTnet' architecture -- TODO confirm these magic numbers.
    net_g_ct = networks.define_G(1, 65, opt.ngf, 'CTnet', opt.norm,
                      not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)
    # init_weights(net_d_dr)
    # init_weights(net_d_ct)
    # init_weights(net_g_dr)
    # init_weights(net_g_ct)
    net_d_ct.to(device)
    net_d_dr.to(device)
    net_g_dr.to(device)
    net_g_ct.to(device)

    # NOTE(review): `one`/`mone` are created but never used below -- leftover
    # from a `loss.backward(one)` style WGAN implementation.
    one = torch.FloatTensor([1])
    mone = one * -1
    one = one.to(device)
    mone= mone.to(device)
    #summary(net_g_dr, (2,65, 65,65))
    if load_net:
        # Optionally resume the CT generator from ./check/net_g<epoch_start>.pth
        # save_filename = 'net_d%s.pth' % epoch_start
        # save_path = os.path.join('./check/', save_filename)
        # load_network(net_d, save_path)
        save_filename = 'net_g%s.pth' % epoch_start
        save_path = os.path.join('./check/', save_filename)
        load_network(net_g_ct, save_path)
    # Loss functions
    #criterion = nn.BCELoss().to(device)
    criterion = nn.MSELoss().to(device)  # NOTE(review): defined but unused below
    criterion1 = nn.L1Loss().to(device)  # cycle-consistency loss

    # Optimizers: one Adam over both discriminators, one over both generators
    optimizer_d = torch.optim.Adam(itertools.chain(net_d_ct.parameters(),net_d_dr.parameters()), lr=0.0001,betas=[0.5,0.9])
    optimizer_g = torch.optim.Adam(itertools.chain(net_g_ct.parameters(),net_g_dr.parameters()), lr=0.0001,betas=[0.5,0.9])

    #optimizer_d = torch.optim.AdamW(net_d.parameters(), lr=0.0001)
    #optimizer_g = torch.optim.AdamW(net_g.parameters(), lr=0.0001)
    #one = torch.FloatTensor([1]).cuda()
    #mone = one * -1
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    def gensample():
        # Yields (index, batch) pairs over the dataset.
        # NOTE(review): `gen` is never consumed below -- dead code.
        for image in enumerate(dataset):
            yield image
    gen=gensample()
    ii=0  # NOTE(review): printed as "iters" in the log message but never incremented

    for epoch in range(MAX_EPOCH):
        # (original comment: add noise to the real data -- no noise is
        # actually added below)
        for it,data in enumerate(dataset):


            # Load one paired batch: 'A' = DR projection, 'B' = CT volume;
            # squeeze(0) drops the batch dimension (batch size is assumed 1
            # -- TODO confirm).
            dr_real = autograd.Variable(data['A'].cuda())
            dr_real=dr_real.squeeze(0)
            ct_real = autograd.Variable(data['B'].cuda())
            ct_real=ct_real.squeeze(0)
            # ---- Critic (discriminator) step: freeze generators ----
            # (inner loop of the WGAN training scheme)
            freeze_params(net_g_ct)
            freeze_params(net_g_dr)
            unfreeze_params(net_d_ct)
            unfreeze_params(net_d_dr)

            # Wrap .data so no gradients flow back into the generators here.
            ct_fake = autograd.Variable(net_g_ct(dr_real).data)
            dr_fake = autograd.Variable(net_g_dr(ct_real).data)

            optimizer_d.zero_grad()
            # WGAN-GP critic loss for the CT discriminator:
            # E[D(fake)] - E[D(real)] + gradient penalty
            loss_dsc_realct = net_d_ct(ct_real).mean()
            #loss_dsc_realct.backward()
            loss_dsc_fakect = net_d_ct(ct_fake.detach()).mean()
            #loss_dsc_fakect.backward()
            gradient_penalty_ct = calc_gradient_penalty(net_d_ct, ct_real, ct_fake)
            #gradient_penalty_ct.backward()
            loss_d_ct=loss_dsc_fakect - loss_dsc_realct+gradient_penalty_ct
            loss_d_ct.backward()
            Wd_ct=loss_dsc_realct-loss_dsc_fakect  # Wasserstein distance estimate (unused below)

            # Same critic loss for the DR discriminator.
            loss_dsc_realdr = net_d_dr(dr_real).mean()
            #loss_dsc_realdr.backward()
            loss_dsc_fakedr = net_d_dr(dr_fake.detach()).mean()
            #loss_dsc_fakedr.backward()
            gradient_penalty_dr = calc_gradient_penalty(net_d_dr, dr_real, dr_fake)
            #gradient_penalty_dr.backward()
            loss_d_dr = loss_dsc_fakedr - loss_dsc_realdr + gradient_penalty_dr
            loss_d_dr.backward()
            Wd_dr = loss_dsc_realdr - loss_dsc_fakedr  # Wasserstein distance estimate (unused below)
            optimizer_d.step()
            # ---- Generator step, every CRITIC_ITERS iterations ----
            if it%CRITIC_ITERS==0:
            #if True:
                unfreeze_params(net_g_ct)
                freeze_params(net_d_ct)
                unfreeze_params(net_g_dr)
                freeze_params(net_d_dr)


                ct_fake_g=net_g_ct(dr_real)
                dr_fake_g=net_g_dr(ct_real)

                # Cycle CT -> DR -> CT: generated DR pushed back through the
                # CT generator should reproduce the real CT.
                # optimizer_g.zero_grad()
                loss_out_dr=criterion1(net_g_ct(dr_fake_g),ct_real)
                # loss_out_dr.backward()
                # optimizer_g.step()
                #net_g_ct.load_state_dict(dict_g_ct)

                # Cycle DR -> CT -> DR: generated CT pushed back through the
                # DR generator should reproduce the real DR.
                # optimizer_g.zero_grad()
                loss_out_ct=criterion1(net_g_dr(ct_fake_g),dr_real)
                # loss_out_ct.backward()
                # optimizer_g.step()

                # Adversarial generator losses -- computed here but only used
                # for logging; they are excluded from `loss_gan` (see the
                # commented alternative below).
                loss_g_ct = - net_d_ct(ct_fake_g).mean()
                #loss_g_ct.backward()
                loss_g_dr = - net_d_dr(dr_fake_g).mean()
                #loss_g_dr.backward()
                loss_gan = loss_out_dr + loss_out_ct
                #loss_gan=loss_out_dr+loss_out_ct+loss_g_ct+loss_g_dr
                #loss_gan = criterion(net_g_ct(dr_real), ct_real) + criterion(net_g_dr(ct_real), dr_real)
                optimizer_g.zero_grad()
                loss_gan.backward()
                optimizer_g.step()

            # NOTE(review): `it%1==0` is always true -- previews are dumped
            # every iteration.
            if it%1==0:
                # torch.roll by -32 on both spatial axes presumably undoes an
                # fftshift before the inverse FFT; slice [32, :, :] takes the
                # central plane of the reconstructed 65^3 volume -- TODO
                # confirm the layout. (torch.irfft is the pre-1.8 PyTorch API.)
                fk_im=toimage(torch.irfft(torch.roll(torch.roll(ct_fake,-32,2),-32,3).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))

                #fk_im=toimage(ct_fake[0,32,:,:].unsqueeze(0))
                 #img_test.append(fk_im)
                # NOTE(review): filename uses int(epoch/dataset_size) here but
                # int(epoch) for the other dumps -- looks inconsistent; verify.
                save_filenamet = 'fakect%s.bmp' % int(epoch/dataset_size)
                img_path = os.path.join('./check/img/', save_filenamet)
                save_image(fk_im, img_path)

                rel_im=toimage(torch.irfft(torch.roll(torch.roll(ct_real,-32,2),-32,3).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                #rel_im = toimage(ct_real[0,32, :, :].unsqueeze(0))
                # img_test.append(rel_im)
                save_image(rel_im, os.path.join('./check/img/', 'Realct%s.bmp' % int(epoch)))

                fake_im = toimage(torch.irfft(torch.roll(torch.roll(dr_fake, -128, 2), -128, 3).permute(1, 2, 3, 0), 2, onesided=False))
                #fake_im =toimage(dr_fake.squeeze(0))
                save_image(fake_im, os.path.join('./check/img/', 'fakedr%s.bmp' % int(epoch)))
                # `ceshi` ("test") presumably runs a quick evaluation of the
                # CT generator -- defined elsewhere; TODO confirm.
                ceshi(net_g_ct)
                message = '(epoch: %d, iters: %d, D_ct: %.3f;[real:%.3f;fake:%.3f], G_ct: %.3f, D_dr: %.3f, G_dr: %.3f) ' % (int(epoch), ii,loss_d_ct,loss_dsc_realct,loss_dsc_fakect,loss_g_ct,loss_d_dr,loss_g_dr)
                print(message)

        # Checkpoint only the CT generator each epoch; the model is moved to
        # CPU for a portable state_dict, then back onto GPU 0.
        save_filename = 'net_g%s.pth' % epoch
        save_path = os.path.join('./check/', save_filename)
        torch.save(net_g_ct.cpu().state_dict(), save_path)
        net_g_ct.cuda(0)
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import numpy as np
import os

# Parse training options and force single-threaded, deterministic, unflipped
# loading so features are encoded in a reproducible order.
opt = TrainOptions().parse()
opt.nThreads = 1
opt.batchSize = 1
opt.serial_batches = True
opt.no_flip = True
opt.instance_feat = True
opt.continue_train = True

name = 'features'
save_path = os.path.join(opt.checkpoints_dir, opt.name)

############ Initialize #########
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = create_model(opt)

########### Encode features ###########
reencode = True
if reencode:
    # One (0, feat_num+1) array per semantic label; rows are presumably
    # appended as instance features are encoded (the extra column likely
    # holds a cluster/instance id -- TODO confirm against pix2pixHD).
    # Indentation normalized to 4 spaces: the original mixed tabs and
    # spaces, which is fragile and violates PEP 8.
    features = {}
    for label in range(opt.label_nc):
        features[label] = np.zeros((0, opt.feat_num+1))
    for i, data in enumerate(dataset):
        # NOTE(review): the loop body appears truncated here -- `feat` is
        # computed but never stored into `features`.
        feat = model.module.encode_features(data['image'], data['inst'])
# ---- Beispiel #3 (example 3) ----
def train_model(dataset_dir, checkpoint_dir):
    """Run the standard CycleGAN/pix2pix-style training loop.

    Builds the dataset, model and visualizer from parsed TrainOptions
    (with dataroot/checkpoints_dir overridden by the arguments), then
    trains for opt.niter + opt.niter_decay epochs, periodically
    displaying visuals, printing/plotting losses and saving checkpoints.
    """
    opt = TrainOptions().parse()  # parsed training options
    opt.dataroot = dataset_dir
    opt.checkpoints_dir = checkpoint_dir

    dataset = create_dataset(opt)  # dataset built from opt.dataset_mode etc.
    dataset_size = len(dataset)
    print('The number of training images = %d' % dataset_size)

    model = create_model(opt)  # model chosen by opt.model
    model.setup(opt)           # load/print networks, create schedulers
    visualizer = Visualizer(opt)  # displays/saves images and plots
    total_iters = 0               # total training iterations so far

    final_epoch = opt.niter + opt.niter_decay
    # Outer loop over epochs; checkpoints are keyed by epoch number.
    for epoch in range(opt.epoch_count, final_epoch + 1):
        epoch_start_time = time.time()  # timer for the whole epoch
        iter_data_time = time.time()    # timer for per-iteration data loading
        epoch_iter = 0                  # iterations within this epoch

        for batch_idx, batch in enumerate(dataset):
            iter_start_time = time.time()  # timer for this iteration's compute
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            model.set_input(batch)      # unpack and preprocess the batch
            model.optimize_parameters() # losses, gradients, weight update

            # Show images on visdom / write them to the HTML file.
            if total_iters % opt.display_freq == 0:
                save_result = (total_iters % opt.update_html_freq == 0)
                model.compute_visuals()
                visualizer.display_current_results(
                    model.get_current_visuals(), epoch, save_result)

            # Print losses and log them to disk.
            if total_iters % opt.print_freq == 0:
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(
                    epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch, float(epoch_iter) / dataset_size, losses)

            # Cache the latest model every <save_latest_freq> iterations.
            if total_iters % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_iters %d)' %
                      (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()

        # Cache the model every <save_epoch_freq> epochs.
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, final_epoch, time.time() - epoch_start_time))
        model.update_learning_rate()  # step the LR schedulers once per epoch
# ---- Beispiel #4 (example 4) ----
import os
import tensorflow as tf
from net.network import GMCNNModel
from data.data import DataLoader
from options.train_options import TrainOptions

# TF1-style graph construction for GMCNN inpainting training: build the
# input pipeline, the model losses, the optimizers and a checkpoint saver.
config = TrainOptions().parse()

model = GMCNNModel()

# training data
# print(config.img_shapes)
dataLoader = DataLoader(filename=config.dataset_path,
                        batch_size=config.batch_size,
                        im_size=config.img_shapes)
images = dataLoader.next()  # symbolic batch tensor from the input pipeline
g_vars, d_vars, losses = model.build_net(images, config=config)

# Learning rate as a non-trainable graph variable so it can be reassigned
# during training (e.g. for decay) without rebuilding the optimizers.
lr = tf.get_variable('lr',
                     shape=[],
                     trainable=False,
                     initializer=tf.constant_initializer(config.lr))

g_optimizer = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9)
# NOTE(review): the discriminator deliberately reuses the same Adam
# optimizer instance as the generator -- confirm against the reference
# GMCNN implementation.
d_optimizer = g_optimizer

g_train_op = g_optimizer.minimize(losses['g_loss'], var_list=g_vars)
d_train_op = d_optimizer.minimize(losses['d_loss'], var_list=d_vars)

saver = tf.train.Saver(max_to_keep=20, keep_checkpoint_every_n_hours=1)
# NOTE(review): snippet appears truncated -- no tf.Session / training loop
# is visible past this point.
# ---- Beispiel #5 (example 5) ----
def main():
    """Train a single WGAN-GP pair: net_g maps DR input ('A') to a
    Fourier-domain CT volume ('B'); net_d is the WGAN critic.

    Every iteration updates the critic with the gradient-penalty loss; the
    generator is updated every CRITIC_ITERS iterations. Every 10 iterations
    the central slice of the inverse-FFT reconstruction is plotted and saved.

    NOTE(review): relies on module-level names defined elsewhere (plt, torch,
    nn, Variable, os, USE_CUDA, BATCH_SIZE, MAX_EPOCH, CRITIC_ITERS,
    load_net, epoch_start, Discriminator, CTGenerator, init_weights,
    load_network, calc_gradient_penalty, toimage, save_image,
    create_dataset, TrainOptions).
    """
    plt.ion()  # enable matplotlib interactive mode for continuous plotting
    opt = TrainOptions().parse()
    # Device used for computation: CPU or GPU
    device = torch.device("cuda" if USE_CUDA else "cpu")
    # Build the discriminator and generator networks
    #net_d = NLayerDiscriminator(opt.output_nc, opt.ndf, n_layers=3)#batchnorm
    net_d = Discriminator(opt.output_nc)
    init_weights(net_d)
    net_g = CTGenerator(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=6)
    init_weights(net_g)
    net_d.to(device)
    net_g.to(device)
    if load_net:
        # Resume both networks from ./check/net_{d,g}<epoch_start>.pth
        save_filename = 'net_d%s.pth' % epoch_start
        save_path = os.path.join('./check/', save_filename)
        load_network(net_d, save_path)
        save_filename = 'net_g%s.pth' % epoch_start
        save_path = os.path.join('./check/', save_filename)
        load_network(net_g, save_path)
    # Loss function
    # NOTE(review): `criterion` and the real/fake labels below are unused --
    # leftovers from a vanilla-GAN version; the WGAN losses are used instead.
    criterion = nn.BCELoss().to(device)
    # Labels for real and fake data
    true_lable = Variable(torch.ones(BATCH_SIZE)).to(device)
    fake_lable = Variable(torch.zeros(BATCH_SIZE)).to(device)
    # Optimizers
    optimizer_d = torch.optim.Adam(net_d.parameters(),
                                   lr=0.0008,
                                   betas=[0.3, 0.9])
    optimizer_g = torch.optim.Adam(net_g.parameters(),
                                   lr=0.0008,
                                   betas=[0.3, 0.9])

    #optimizer_d = torch.optim.AdamW(net_d.parameters(), lr=0.0001)
    #optimizer_g = torch.optim.AdamW(net_g.parameters(), lr=0.0001)
    #one = torch.FloatTensor([1]).cuda()
    #mone = one * -1
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    for epoch in range(MAX_EPOCH):
        # (original comment: add noise to the real data -- no noise is added)
        for ii, data in enumerate(dataset):
            #real_data = np.vstack([POINT*POINT + np.random.normal(0, 0.01, SAMPLE_NUM) for _ in range(BATCH_SIZE)])
            #real_data = np.vstack([np.sin(POINT) + np.random.normal(0, 0.01, SAMPLE_NUM) for _ in range(BATCH_SIZE)])
            #real_data = Variable(torch.Tensor(real_data)).to(device)
            # (originally: random noise as generator input; here the 'A'
            # image serves as the conditioning input instead)
            #g_noises = np.random.randn(BATCH_SIZE, N_GNET)
            #g_noises = Variable(torch.Tensor(g_noises)).to(device)
            g_noises = data['A'].cuda()
            real_data = data['B'].cuda()

            # ---- Train the discriminator (critic) ----
            # for p in net_d.parameters():  # reset requires_grad
            #     p.requires_grad = True  # they are set to False below in netG update

            optimizer_d.zero_grad()
            # Critic score on real data (negated: WGAN maximizes E[D(real)])
            d_real = net_d(real_data)
            #loss_d_real = criterion(d_real, true_lable)
            loss_d_real = -d_real.mean()
            #loss_d_real.backward()
            # Critic score on generated data; detach() blocks generator grads
            fake_date = net_g(g_noises)
            d_fake = net_d(fake_date.detach())
            #loss_d_fake = criterion(d_fake, fake_lable)
            loss_d_fake = d_fake.mean()
            #loss_d_fake.backward()

            # train with gradient penalty
            gradient_penalty = calc_gradient_penalty(net_d, real_data,
                                                     fake_date)
            #gradient_penalty.backward()

            D_cost = loss_d_fake + loss_d_real + gradient_penalty
            D_cost.backward()
            Wasserstein_D = loss_d_real - loss_d_fake  # W-distance estimate (for logging)
            optimizer_d.step()
            # ---- Train the generator every CRITIC_ITERS iterations ----
            if ii % CRITIC_ITERS == 0:
                # for p in net_d.parameters():
                #     p.requires_grad = False  # to avoid computation
                optimizer_g.zero_grad()
                fake_date = net_g(g_noises)
                d_fake = net_d(fake_date)
                # Generator adversarial loss: maximize the critic's score
                #loss_g = criterion(d_fake, true_lable)
                loss_g = -d_fake.mean()
                loss_g.backward()
                optimizer_g.step()
                G_cost = -loss_g
                # Debug print of one layer's gradient to verify backprop.
                for name, parms in net_g.named_parameters():
                    if name == 'model.2.weight':
                        print('层:',name,parms.size(),'-->name:', name, '-->grad_requirs:', parms.requires_grad, \
                          ' -->grad_value:', parms.grad[0])
                a = 1
                # for name, parms in self.netG_A.named_parameters():

            # Plot generated images and stats periodically
            # NOTE(review): the original comment says "every 200 steps" but
            # the condition below triggers every 10.
            if ii % 10 == 0:
                #print(fake_date[0]) plt.ion()
                plt.ion()
                plt.cla()
                # plt.plot(POINT, fake_date[0].to('cpu').detach().numpy(), c='#4AD631', lw=2,
                #          label="generated line")  # generated data
                # plt.plot(POINT, real_data[0].to('cpu').detach().numpy(), c='#74BCFF', lw=3, label="real sin")  # real data
                #prob = (loss_d_real.mean() + 1 - loss_d_fake.mean()) / 2.

                img_test = []
                # Inverse FFT of the generated volume; slice [32, :, :] takes
                # the central plane -- assumes a 65^3 Fourier-domain volume;
                # TODO confirm. (torch.irfft is the pre-1.8 PyTorch API.)
                fk_im = toimage(
                    torch.irfft(fake_date.permute(1, 2, 3, 0),
                                2,
                                onesided=False)[32, :, :].unsqueeze(0))
                img_test.append(fk_im)
                save_filenamet = 'fake%s.bmp' % epoch
                img_path = os.path.join('./check/img/', save_filenamet)
                save_image(fk_im, img_path)
                rel_im = toimage(
                    torch.irfft(real_data.squeeze(0).permute(1, 2, 3, 0),
                                2,
                                onesided=False)[32, :, :].unsqueeze(0))
                img_test.append(rel_im)

                # Side-by-side fake vs. real, annotated with current stats.
                for it in range(1, 3):
                    plt.subplot(1, 2, it)
                    plt.imshow(img_test[it - 1])
                plt.text(-1,
                         81,
                         'D accuracy=%.2f ' % (D_cost.mean()),
                         fontdict={'size': 15})
                plt.text(-1,
                         85,
                         'G accuracy=%.2f ' % (G_cost),
                         fontdict={'size': 15})
                plt.text(-1,
                         89,
                         'W accuracy=%.2f ' % (Wasserstein_D),
                         fontdict={'size': 15})
                plt.text(-1,
                         95,
                         'epoch=%.2f ' % (epoch),
                         fontdict={'size': 15})
                plt.show()
                # plt.ylim(-2, 2)
                plt.draw(), plt.pause(0.1), plt.clf()
        # Checkpoint both networks each epoch (CPU state_dict for
        # portability, then move back to GPU 0).
        save_filename = 'net_d%s.pth' % epoch
        save_path = os.path.join('./check/', save_filename)
        torch.save(net_d.cpu().state_dict(), save_path)
        net_d.cuda(0)
        save_filename = 'net_g%s.pth' % epoch
        save_path = os.path.join('./check/', save_filename)
        torch.save(net_g.cpu().state_dict(), save_path)
        net_g.cuda(0)

    plt.ioff()
    plt.show()
def main():
    """Adversarial domain adaptation for semantic segmentation (Python 2).

    Trains a two-headed segmentation model on labeled source images and
    unlabeled target images, with a domain discriminator `model_D`, an
    entropy loss on confidently-predicted target regions, a weighted
    adversarial loss driven by inter-head disagreement (`weightmap`), and a
    weight-discrepancy (cosine) loss between the two classifier heads.

    NOTE(review): this block uses Python 2 `print` statements and the
    Python 2 `.next()` iterator method. It also relies on module-level names
    defined elsewhere (TrainOptions, Timer, os, torch, nn, F, np, Variable,
    tensorboardX, tqdm, cudnn, CreateSrcDataLoader, CreateTrgDataLoader,
    CreateModel, CreateDiscriminator, CrossEntropy2d, EntropyLoss,
    WeightedBCEWithLogitsLoss, weightmap).
    """
    # torch.manual_seed(1234)
    # torch.cuda.manual_seed(1234)
    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(
        sourceloader)

    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)

    start_iter = 0
    # warm_start = 0 means the weighted adversarial loss is used from the
    # very first iteration (the plain-BCE branch below is effectively dead).
    warm_start = 0

    if args.restore_from is not None:
        # Recover the iteration count from the checkpoint filename.
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

    train_writer = tensorboardX.SummaryWriter(
        os.path.join(args.snapshot_dir, "logs", model_name))

    bce_loss = torch.nn.BCEWithLogitsLoss()
    l1_loss = torch.nn.L1Loss()
    cos_loss = torch.nn.CosineSimilarity(dim=0, eps=1e-06)
    cent_loss = EntropyLoss()
    weighted_bce_loss = WeightedBCEWithLogitsLoss()

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()
    # Names of loss variables to log; fetched below via eval().
    loss = [
        'loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake', 'loss_D_src_real',
        'loss_D_trg_real'
    ]
    _t['iter time'].tic()

    pbar = tqdm(range(start_iter, args.num_steps_stop))
    #for i in range(start_iter, args.num_steps):
    for i in pbar:

        model.adjust_learning_rate(args, optimizer, i)
        model_D.adjust_learning_rate(args, optimizer_D, i)

        optimizer.zero_grad()
        optimizer_D.zero_grad()
        # Freeze the discriminator while updating the segmentation model.
        for param in model_D.parameters():
            param.requires_grad = False

        # ---- Supervised segmentation loss on the source domain ----
        src_img, src_lbl, _, _ = sourceloader_iter.next()
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()
        src_seg_score, src_seg_score2 = model(src_img)
        loss_seg_src1 = CrossEntropy2d(src_seg_score, src_lbl)
        loss_seg_src2 = CrossEntropy2d(src_seg_score2, src_lbl)
        loss_seg_src = loss_seg_src1 + loss_seg_src2
        loss_seg_src.backward()

        # ---- Target-domain forward pass ----
        if args.data_label_folder_target is not None:
            trg_img, trg_lbl, _, _ = targetloader_iter.next()
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(
                trg_lbl.long()).cuda()
            trg_seg_score = model(trg_img)
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = targetloader_iter.next()
            trg_img = Variable(trg_img).cuda()
            trg_seg_score, trg_seg_score2 = model(trg_img)
            (h, w) = trg_img.shape[-2:]

        # NOTE(review): `(h, w)` (and `trg_seg_score2`) are only assigned in
        # the else-branch above but used unconditionally below -- the
        # labeled-target branch would raise; verify it is never taken.
        weight_map = weightmap(F.softmax(trg_seg_score, dim=1),
                               F.softmax(trg_seg_score2, dim=1))

        # Use 2cls prediction as indicator of confident region
        indicator = 1 - weight_map
        # Keep only the most confident ~5% of pixels (above the second-to-
        # last histogram bin edge); mask is broadcast over the 19 classes.
        tmp = np.histogram(indicator.cpu().detach().numpy(), bins=20)
        threshold = tmp[1][-2]
        mask = (indicator > threshold)
        mask = mask.repeat(1, 19, 1, 1)

        # Entropy minimization on the confident target regions.
        loss_seg_trg1 = cent_loss(trg_seg_score, mask)
        loss_seg_trg2 = cent_loss(trg_seg_score2, mask)
        loss_seg_trg = loss_seg_trg1 + loss_seg_trg2

        # Discriminator output for target predictions (summed heads),
        # upsampled back to image resolution.
        outD_trg = model_D(F.softmax(trg_seg_score + trg_seg_score2, dim=1))
        outD_trg = nn.functional.upsample(outD_trg, (h, w),
                                          mode='bilinear',
                                          align_corners=True)
        #ipdb.set_trace()

        #Adaptive Adversarial Loss
        # Target predictions pushed toward the source label (0); after
        # warm_start the BCE is weighted by the disagreement map.
        if (i > warm_start):
            loss_D_trg_fake = weighted_bce_loss(
                outD_trg,
                Variable(torch.FloatTensor(
                    outD_trg.data.size()).fill_(0)).cuda(), weight_map,
                args.epsilon, args.lambda_local)
        else:
            loss_D_trg_fake = bce_loss(
                outD_trg,
                Variable(torch.FloatTensor(
                    outD_trg.data.size()).fill_(0)).cuda())

        #loss_agree= l1_loss(F.softmax(trg_seg_score), F.softmax(trg_seg_score2))

        loss_trg = args.lambda_adv_target * loss_D_trg_fake + args.tar_vat * loss_seg_trg
        loss_trg.backward()
        #Weight Discrepancy Loss

        # Flatten the weights of the two classifier heads for the cosine
        # discrepancy loss.
        W5 = None
        W6 = None
        if args.model == 'DeepLab2':

            for (w5, w6) in zip(model.layer5.parameters(),
                                model.layer6.parameters()):
                if W5 is None and W6 is None:
                    W5 = w5.view(-1)
                    W6 = w6.view(-1)
                else:
                    W5 = torch.cat((W5, w5.view(-1)), 0)
                    W6 = torch.cat((W6, w6.view(-1)), 0)

        #ipdb.set_trace()
        #loss_weight = (torch.matmul(W5, W6) / (torch.norm(W5) * torch.norm(W6)) + 1) # +1 is for a positive loss
        # loss_weight = loss_weight  * damping * 2
        # NOTE(review): if args.model != 'DeepLab2', W5/W6 stay None and
        # cos_loss(None, None) will fail -- verify the supported models.
        loss_weight = args.weight_div * (cos_loss(W5, W6) + 1)
        loss_weight.backward()

        # ---- Discriminator update: unfreeze and train on detached scores ----
        for param in model_D.parameters():
            param.requires_grad = True

        src_seg_score, src_seg_score2 = src_seg_score.detach(
        ), src_seg_score2.detach()

        outD_src = model_D(F.softmax(src_seg_score + src_seg_score2, dim=1))
        # Source predictions labeled 0 (source domain).
        loss_D_src_real = bce_loss(
            outD_src,
            Variable(torch.FloatTensor(outD_src.data.size()).fill_(0)).cuda())
        loss_D_src_real.backward()

        trg_seg_score, trg_seg_score2 = trg_seg_score.detach(
        ), trg_seg_score2.detach()
        weight_map = weight_map.detach()

        outD_trg = model_D(F.softmax(trg_seg_score + trg_seg_score2, dim=1))
        outD_trg = nn.functional.upsample(outD_trg, (h, w),
                                          mode='bilinear',
                                          align_corners=True)
        #Adaptive Adversarial Loss
        # Target predictions labeled 1 (target domain) for the discriminator.
        if (i > warm_start):
            loss_D_trg_real = weighted_bce_loss(
                outD_trg,
                Variable(torch.FloatTensor(
                    outD_trg.data.size()).fill_(1)).cuda(), weight_map,
                args.epsilon, args.lambda_local)
        else:
            loss_D_trg_real = bce_loss(
                outD_trg,
                Variable(torch.FloatTensor(
                    outD_trg.data.size()).fill_(1)).cuda())

        loss_D_trg_real.backward()

        d_loss = loss_D_src_real.data + loss_D_trg_real.data

        optimizer.step()
        optimizer_D.step()

        # Log each named loss; eval() resolves the local variable by name.
        for m in loss:
            train_writer.add_scalar(m, eval(m), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            print 'taking snapshot ...'
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '_D.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print '[it %d][src seg loss %.4f][adv loss %.4f][d loss %.4f][div loss %.4f][lr %.4f][%.2fs]' % \
                    (i + 1, loss_seg_src.data, loss_D_trg_fake.data,d_loss,loss_weight.data, optimizer.param_groups[0]['lr']*10000, _t['iter time'].diff)
            if i + 1 > args.num_steps_stop:
                print 'finish training'
                break
            _t['iter time'].tic()
# ---- Beispiel #7 (example 7) ----
def main():
    """FDA-style self-training round for domain-adaptive segmentation.

    Source images are translated into the target "style" via
    FDA_source_to_target, and pseudo-labeled target images from the previous
    round supply extra supervision. Each step optimizes the sum of the
    source segmentation loss and the pseudo-label segmentation loss plus an
    entropy term on the pseudo-labeled batch.

    NOTE(review): relies on module-level names defined elsewhere
    (TrainOptions, Timer, os, torch, Variable, cudnn, sio, CS_weights,
    IMG_MEAN, CreateSrcDataLoader, CreateTrgDataLoader,
    CreatePseudoTrgLoader, CreateModel, FDA_source_to_target). The
    `.next()` iterator calls are the Python 2 API.
    """
    opt = TrainOptions()
    args = opt.initialize()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    sourceloader_iter, targetloader_iter = iter(sourceloader), iter(
        targetloader)

    pseudotrgloader = CreatePseudoTrgLoader(
        args
    )  # Pseudo labels generated from previous round are used as target.
    pseudoloader_iter = iter(pseudotrgloader)

    model, optimizer = CreateModel(args)

    start_iter = 0
    if args.restore_from is not None:
        # Recover the iteration count from the checkpoint filename.
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

    cudnn.enabled = True
    cudnn.benchmark = True

    model.train()
    model.cuda()

    # losses to log
    loss = ['loss_seg_src', 'loss_seg_psu']
    loss_train = 0.0   # running source seg loss (averaged every print_freq)
    loss_val = 0.0     # running pseudo-label seg loss
    loss_pseudo = 0.0  # NOTE(review): never updated below
    loss_train_list = []
    loss_val_list = []
    loss_pseudo_list = []

    # Placeholder; replaced by the broadcast IMG_MEAN on the first batch.
    mean_img = torch.zeros(1, 1)
    class_weights = Variable(CS_weights).cuda()

    _t['iter time'].tic()
    for i in range(start_iter, args.num_steps):
        model.adjust_learning_rate(args, optimizer, i)  # adjust learning rate
        optimizer.zero_grad()  # zero grad

        src_img, src_lbl, _, _ = sourceloader_iter.next()  # new batch source
        trg_img, trg_lbl, _, _ = targetloader_iter.next()  # new batch target
        psu_img, psu_lbl, _, _ = pseudoloader_iter.next()

        # NOTE(review): `scr_img_copy` is never used below.
        scr_img_copy = src_img.clone()

        # Lazily build the per-pixel mean image once the batch shape is known.
        if mean_img.shape[-1] < 2:
            B, C, H, W = src_img.shape
            mean_img = IMG_MEAN.repeat(B, 1, H, W)

        #-------------------------------------------------------------------#

        # 1. source to target, target to target
        # Swap low-frequency FFT amplitudes so source images adopt the
        # target style (FDA); L controls the swapped band size.
        src_in_trg = FDA_source_to_target(src_img, trg_img,
                                          L=args.LB)  # src_lbl
        trg_in_trg = trg_img

        # 2. subtract mean
        src_img = src_in_trg.clone() - mean_img  # src_1, trg_1, src_lbl
        trg_img = trg_in_trg.clone() - mean_img  # trg_1, trg_0, trg_lbl
        psu_img = psu_img.clone() - mean_img

        #-------------------------------------------------------------------#

        # evaluate and update params #####
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()  # to gpu
        src_seg_score = model(src_img,
                              lbl=src_lbl,
                              weight=class_weights,
                              ita=args.ita)  # forward pass
        loss_seg_src = model.loss_seg  # get loss
        # NOTE(review): loss_ent_src is read but not included in loss_all.
        loss_ent_src = model.loss_ent

        # use pseudo label as supervision
        psu_img, psu_lbl = Variable(psu_img).cuda(), Variable(
            psu_lbl.long()).cuda()
        psu_seg_score = model(psu_img,
                              lbl=psu_lbl,
                              weight=class_weights,
                              ita=args.ita)
        loss_seg_psu = model.loss_seg
        loss_ent_psu = model.loss_ent

        loss_all = loss_seg_src + (loss_seg_psu + args.entW * loss_ent_psu
                                   )  # loss of seg on src, and ent on s and t
        loss_all.backward()
        optimizer.step()

        loss_train += loss_seg_src.detach().cpu().numpy()
        loss_val += loss_seg_psu.detach().cpu().numpy()

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][psu seg loss %.4f][lr %.4f][%.2fs]' % \
                    (i + 1, loss_seg_src.data, loss_seg_psu.data, optimizer.param_groups[0]['lr']*10000, _t['iter time'].diff) )

            # Dump the current (styled, mean-subtracted) batches for offline
            # inspection in MATLAB format.
            sio.savemat(args.tempdata, {
                'src_img': src_img.cpu().numpy(),
                'trg_img': trg_img.cpu().numpy()
            })

            loss_train /= args.print_freq
            loss_val /= args.print_freq
            loss_train_list.append(loss_train)
            loss_val_list.append(loss_val)
            sio.savemat(args.matname, {
                'loss_train': loss_train_list,
                'loss_val': loss_val_list
            })
            loss_train = 0.0
            loss_val = 0.0

            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
# ---- Beispiel #8 (example 8) ----
from options.train_options import TrainOptions
from augments.augs import load_augments
import os
import os.path as osp
import random
import cv2

if __name__ == '__main__':
    # Parse experiment options and pull out what we need.
    opts = TrainOptions()
    args = opts.initialize()

    experiment = args['experiment']
    class_dir = f'{experiment.class_id}'.rjust(5, '0')
    data_dir = osp.join(experiment.data_dir, 'Final_Training', 'Images', class_dir)
    out_size = tuple(experiment.size)

    # Absolute paths of every file inside the class directory.
    candidates = [osp.join(data_dir, name) for name in os.listdir(data_dir)]

    # Sampling WITH replacement: the same file may be picked more than once.
    random_samples = random.choices(candidates, k=experiment.samples)

    for path in random_samples:
        # Skip the annotation CSV that ships alongside the images.
        if path.endswith('.csv'):
            continue
        image = cv2.resize(cv2.imread(path), out_size)
        image = load_augments(args['augmentations'], top=1)(image=image)
        # Save next to the original, e.g. foo.ppm -> foo_aug.ppm.
        cv2.imwrite(path[:-4] + '_aug' + path[-4:], image)
    print('Total Samples', len(set(random_samples)))
Beispiel #9
0
import os
import numpy as np
import torch
from torch.autograd import Variable
from collections import OrderedDict
from subprocess import call
import fractions
def lcm(a,b): return abs(a * b)/fractions.gcd(a,b) if a and b else 0

from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer

opt_obj = TrainOptions()
opt = opt_obj.parse()
# File recording the (epoch, iteration) at which training last stopped.
iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
if opt.continue_train:
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
    except (OSError, ValueError):
        # Missing or malformed iter.txt: restart from scratch. Catch only
        # what np.loadtxt raises instead of a bug-hiding bare `except`.
        start_epoch, epoch_iter = 1, 0
    print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
else:
    start_epoch, epoch_iter = 1, 0

# Align print_freq with the batch size so logging lands on batch boundaries.
opt.print_freq = lcm(opt.print_freq, opt.batchSize)
if opt.debug:
    opt.display_freq = 1
    opt.print_freq = 1
Beispiel #10
0
def main():
    """Parse training options, split the batch across GPUs, and launch training."""
    cfg = TrainOptions().parse()
    num_gpus = torch.cuda.device_count()
    cfg.NUM_GPUS = num_gpus
    # Per-process batch size: divide the global batch across available GPUs.
    cfg.batch_size = int(cfg.batch_size / max(1, num_gpus))
    cfg.phase = 'train'
    launch_job(cfg=cfg, init_method=cfg.init_method, func=train)
from options.train_options import TrainOptions
from models import create_model
from util.visualizer import save_images
from util import html
from PIL import Image

import string
import torch
import torchvision
import torchvision.transforms as transforms
import coremltools as ct

from util import util
import numpy as np

# Gather raw options and force them into the colorization test configuration.
opt = TrainOptions().gather_options()
opt.isTrain = True
opt.name = "siggraph_caffemodel"
# opt.name = "siggraph_retrained"
opt.mask_cent = 0
opt.gpu_ids = []  # CPU only
opt.load_model = True
opt.num_threads = 1  # test code only supports num_threads = 1
opt.batch_size = 1  # test code only supports batch_size = 1
opt.display_id = -1  # no visdom display
opt.serial_batches = True
opt.aspect_ratio = 1.
# phase must be set before dataroot, which interpolates it.
opt.phase = 'val'
opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase

# process opt.suffix
Beispiel #12
0
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import time
from collections import OrderedDict
from options.train_options import TrainOptions
from data.data_loader import CreateFaceConDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
import os
import numpy as np
import torch
from torch.autograd import Variable
import tensorboardX
import random

# Training options, plus a second copy tweaked for validation-time loading.
opt = TrainOptions().parse()
opt_test = TrainOptions().parse()
for attr, value in (('phase', 'val'), ('nThreads', 1), ('batchSize', 1),
                    ('serial_batches', False), ('no_flip', True)):
    setattr(opt_test, attr, value)
data_loader_test = CreateFaceConDataLoader(opt_test)
dataset_test_ = data_loader_test.load_data()
dataset_test = dataset_test_.dataset
'''
for i, data in enumerate(dataset_test_):
    print(i)
    dataset_test.append(data)
    if (i > 10000):
        break
Beispiel #13
0
def main():
    """Train a pix2pix or PAN model over the configured dataset.

    Raises:
        ValueError: if ``opt.model`` is neither 'pix2pix' nor 'pan'.
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    # Build the requested pix2pix/PAN model.
    if opt.model == 'pix2pix':
        assert (opt.dataset_mode == 'aligned')
        from models.pix2pix_model import Pix2PixModel
        model = Pix2PixModel()
        model.initialize(opt)
    elif opt.model == 'pan':
        from models.pan_model import PanModel
        model = PanModel()
        model.initialize(opt)
    else:
        # Fail fast; the original fell through and hit a NameError on
        # `model` at the first use below.
        raise ValueError('Model [%s] not recognized.' % opt.model)

    total_steps = 0

    batch_size = opt.batchSize
    print_freq = opt.print_freq
    epoch_count = opt.epoch_count
    niter = opt.niter
    niter_decay = opt.niter_decay
    save_latest_freq = opt.save_latest_freq
    save_epoch_freq = opt.save_epoch_freq

    for epoch in range(epoch_count, niter + niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            # data --> (1, 3, 256, 256)
            iter_start_time = time.time()
            total_steps += batch_size
            epoch_iter += batch_size
            model.set_input(data)
            model.optimize_parameters()

            # Print the running losses every print_freq steps.
            if total_steps % print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / batch_size

                message = '(epoch: %d, iters: %d, time: %.3f) ' % (
                    epoch, epoch_iter, t)
                for k, v in errors.items():
                    message += '%s: %.3f ' % (k, v)
                print(message)

            # save latest weights
            if total_steps % save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        # save weights periodically
        if epoch % save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, niter + niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()
Beispiel #14
0
def main():
    """Train a video-prediction model, periodically validating and logging
    losses, sample grids and PSNR/SSIM plots to TensorBoard.

    Raises:
        ValueError: if ``opt.data`` names an unsupported dataset.
    """
    opt = TrainOptions().parse()

    # Determine validation step options that might differ from training
    if opt.data == 'KTH':
        val_pick_mode = 'Slide'
        val_gpu_ids = [opt.gpu_ids[0]]
        val_batch_size = 1
    elif opt.data in ['UCF', 'HMDB51', 'S1M']:
        val_pick_mode = 'First'
        val_gpu_ids = opt.gpu_ids
        # Floor division: `/` produced a float batch size under Python 3.
        val_batch_size = opt.batch_size // 2
    else:
        raise ValueError('Dataset [%s] not recognized.' % opt.data)

    expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
    makedir(expr_dir)
    tb_dir = os.path.join(opt.tensorboard_dir, opt.name)
    makedir(tb_dir)

    # Dump the resolved options for reproducibility.
    file_name = os.path.join(expr_dir, 'train_opt.txt')
    with open(file_name, 'wt') as opt_file:
        listopt(opt, opt_file)

    log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
    print('after reading options')
    include_following = (opt.model_type != 'mcnet')
    data_loader = CustomDataLoader(opt.data, opt.c_dim, opt.dataroot,
                                   opt.textroot, opt.video_list, opt.K, opt.T,
                                   opt.backwards, opt.flip, opt.pick_mode,
                                   opt.image_size, include_following, opt.skip,
                                   opt.F, opt.batch_size, opt.serial_batches,
                                   opt.nThreads)
    print(data_loader.name())
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# training videos = %d' % dataset_size)

    env = create_environment(
        opt.model_type, opt.gf_dim, opt.c_dim, opt.gpu_ids, True,
        opt.checkpoints_dir, opt.name, opt.K, opt.T, opt.F, opt.image_size,
        opt.batch_size, opt.which_update, opt.comb_type, opt.shallow, opt.ks,
        opt.num_block, opt.layers, opt.kf_dim, opt.enable_res, opt.rc_loc,
        opt.no_adversarial, opt.alpha, opt.beta, opt.D_G_switch, opt.margin,
        opt.lr, opt.beta1, opt.sn, opt.df_dim, opt.Ip, opt.continue_train,
        opt.comb_loss)

    total_updates = env.start_update
    writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # Loop over the dataset repeatedly until opt.max_iter updates are reached.
    while True:
        for data in dataset:
            iter_start_time = time.time()

            # Enable losses on intermediate and final predictions partway through training
            if total_updates >= opt.inter_sup_update:
                env.enable_inter_loss()
            if total_updates >= opt.final_sup_update:
                env.enable_final_loss()

            # Update model
            total_updates += 1
            env.set_inputs(data)
            env.optimize_parameters()

            if total_updates % opt.print_freq == 0:
                errors = env.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batch_size
                writer.add_scalar('iter_time', t, total_updates)
                for key in errors.keys():
                    writer.add_scalar('loss/%s' % (key), errors[key],
                                      total_updates)
                print_current_errors(log_name, total_updates, errors, t)

            if total_updates % opt.display_freq == 0:
                visuals = env.get_current_visuals()
                grid = visual_grid(visuals, opt.K, opt.T)
                writer.add_image('current_batch', grid, total_updates)

            if total_updates % opt.save_latest_freq == 0:
                print('saving the latest model (update %d)' % total_updates)
                env.save('latest', total_updates)
                env.save(total_updates, total_updates)

            if total_updates % opt.validate_freq == 0:
                # Run validation with the dataset-specific options chosen above.
                psnr_plot, ssim_plot, grid = val(
                    opt.c_dim, opt.data, opt.T * 2, opt.dataroot, opt.textroot,
                    'val_data_list.txt', opt.K, opt.backwards, opt.flip,
                    val_pick_mode, opt.image_size, val_gpu_ids, opt.model_type,
                    opt.skip, opt.F, val_batch_size, True, opt.nThreads,
                    opt.gf_dim, False, opt.checkpoints_dir, opt.name,
                    opt.no_adversarial, opt.alpha, opt.beta, opt.D_G_switch,
                    opt.margin, opt.lr, opt.beta1, opt.sn, opt.df_dim, opt.Ip,
                    opt.comb_type, opt.comb_loss, opt.shallow, opt.ks,
                    opt.num_block, opt.layers, opt.kf_dim, opt.enable_res,
                    opt.rc_loc, opt.continue_train, 'latest')
                writer.add_image('psnr', psnr_plot, total_updates)
                writer.add_image('ssim', ssim_plot, total_updates)
                writer.add_image('samples', grid, total_updates)

            if total_updates >= opt.max_iter:
                env.save('latest', total_updates)
                break

        if total_updates >= opt.max_iter:
            break
Beispiel #15
0
def main():
    """Adversarial domain-adaptation training loop.

    Alternates between updating the segmentation model (supervised on source,
    adversarial on target) and the domain discriminator, evaluating on the
    target test split every ``save_pred_every`` iterations.
    """
    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateActSrcDataLoader(
        args), CreateActTrgDataLoader(args, 'train')
    testloader = CreateActTrgDataLoader(args, 'test')
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(
        sourceloader)

    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)

    # Resume from the iteration encoded in the checkpoint filename, if any.
    start_iter = 0
    if args.restore_from is not None:
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

    train_writer = tensorboardX.SummaryWriter(
        os.path.join(args.snapshot_dir, "logs"))

    bce_loss = torch.nn.BCEWithLogitsLoss()

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()
    # Names of the scalars mirrored to TensorBoard every iteration.
    loss_names = [
        'loss_src', 'loss_trg', 'loss_D_trg_fake', 'loss_D_src_real',
        'loss_D_trg_real', 'eval_loss'
    ]
    _t['iter time'].tic()
    best_loss_eval = None
    best_step = 0
    eval_loss = np.array([0])
    for i in range(start_iter, args.num_steps):

        model.module.adjust_learning_rate(args, optimizer, i)
        model_D.module.adjust_learning_rate(args, optimizer_D, i)

        optimizer.zero_grad()
        optimizer_D.zero_grad()
        # Freeze the discriminator while updating the segmentation model.
        for param in model_D.parameters():
            param.requires_grad = False

        try:
            src_img, src_lbl, _, _ = next(sourceloader_iter)
        except StopIteration:
            sourceloader_iter = iter(sourceloader)
            src_img, src_lbl, _, _ = next(sourceloader_iter)
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()
        src_score, loss_src = model(src_img, lbl=src_lbl)
        loss_src.mean().backward()

        try:
            trg_img, trg_lbl, _, _ = next(targetloader_iter)
        except StopIteration:
            targetloader_iter = iter(targetloader)
            trg_img, trg_lbl, _, _ = next(targetloader_iter)
        trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(
            trg_lbl.long()).cuda()
        trg_score, loss_trg = model(trg_img, lbl=trg_lbl)

        outD_trg, loss_D_trg_fake = model_D(F.softmax(trg_score, dim=1),
                                            0)  # do not apply softmax

        # Fool the discriminator: target predictions should look like source.
        loss_trg = args.lambda_adv_target * loss_D_trg_fake + loss_trg
        loss_trg.mean().backward()

        # Unfreeze the discriminator for its own update.
        for param in model_D.parameters():
            param.requires_grad = True

        src_score, trg_score = src_score.detach(), trg_score.detach()

        outD_src, model_D_loss = model_D(F.softmax(src_score, dim=1),
                                         0)  # do not apply softmax

        loss_D_src_real = model_D_loss / 2
        loss_D_src_real.mean().backward()

        outD_trg, model_D_loss = model_D(F.softmax(trg_score, dim=1),
                                         1)  # do not apply softmax

        loss_D_trg_real = model_D_loss / 2
        loss_D_trg_real.mean().backward()

        optimizer.step()
        optimizer_D.step()

        # Log via an explicit mapping instead of eval() on variable names.
        scalars = {
            'loss_src': loss_src,
            'loss_trg': loss_trg,
            'loss_D_trg_fake': loss_D_trg_fake,
            'loss_D_src_real': loss_D_src_real,
            'loss_D_trg_real': loss_D_trg_real,
            'eval_loss': eval_loss,  # stale between evaluations, as before
        }
        for m in loss_names:
            train_writer.add_scalar(m, scalars[m].mean(), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            with torch.no_grad():
                model.eval()
                eval_loss = 0
                # NOTE(review): test batches are not moved to CUDA here,
                # unlike the training batches above -- verify the model
                # wrapper handles device placement.
                for test_img, test_lbl, _, _ in testloader:
                    test_score, loss_test = model(test_img, lbl=test_lbl)
                    eval_loss += loss_test.mean().item() * test_img.size(0)
                eval_loss /= len(testloader.dataset)
                if best_loss_eval is None or eval_loss < best_loss_eval:
                    best_loss_eval = eval_loss
                    best_step = i + 1
                print('taking snapshot ... eval_loss: {}'.format(eval_loss))
                torch.save(
                    model.module.state_dict(),
                    os.path.join(args.snapshot_dir,
                                 str(i + 1) + '.pth'))
                eval_loss = np.array([eval_loss])
            # Restore training mode: previously the model stayed in eval mode
            # (frozen BatchNorm/Dropout) for all steps after the first
            # snapshot, silently corrupting the rest of training.
            model.train()

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src loss %.4f][lr %.4f][%.2fs]' % \
                    (i + 1, loss_src.mean().data, optimizer.param_groups[0]['lr']*10000, _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
Beispiel #16
0
def train():
    """vid2vid-style training loop: multi-scale spatial/temporal GAN training
    of a video generator with flow-based warping losses.

    Reads all configuration from TrainOptions; resumes from iter.txt when
    --continue_train is set. The logic below is heavily order-dependent
    (generator forward feeds the discriminators within the same inner loop),
    so only documentation is added here.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    if opt.dataset_mode == 'pose':
        print('#training frames = %d' % dataset_size)
    else:
        print('#training videos = %d' % dataset_size)

    ### initialize models
    modelG, modelD, flowNet = create_model(opt)
    visualizer = Visualizer(opt)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    ### if continue training, recover previous states
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except:
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        # Re-apply learning-rate decay / unfreezing / sequence-length growth
        # that would already have happened by the resumed epoch.
        if start_epoch > opt.niter:
            modelG.module.update_learning_rate(start_epoch - 1)
            modelD.module.update_learning_rate(start_epoch - 1)
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                start_epoch > opt.niter_fix_global):
            modelG.module.update_fixed_params()
        if start_epoch > opt.niter_step:
            data_loader.dataset.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
            modelG.module.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
    else:
        start_epoch, epoch_iter = 1, 0

    ### set parameters
    n_gpus = opt.n_gpus_gen // opt.batchSize  # number of gpus used for generator for each batch
    tG, tD = opt.n_frames_G, opt.n_frames_D
    tDB = tD * opt.output_nc
    s_scales = opt.n_scales_spatial
    t_scales = opt.n_scales_temporal
    input_nc = 1 if opt.label_nc != 0 else opt.input_nc
    output_nc = opt.output_nc

    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    total_steps = total_steps // opt.print_freq * opt.print_freq

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0

            _, n_frames_total, height, width = data['B'].size(
            )  # n_frames_total = n_frames_load * n_loadings + tG - 1
            n_frames_total = n_frames_total // opt.output_nc
            n_frames_load = opt.max_frames_per_gpu * n_gpus  # number of total frames loaded into GPU at a time for each batch
            n_frames_load = min(n_frames_load, n_frames_total - tG + 1)
            t_len = n_frames_load + tG - 1  # number of loaded frames plus previous frames

            fake_B_last = None  # the last generated frame from previous training batch (which becomes input to the next batch)
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None  # all real/generated frames so far
            real_B_skipped, fake_B_skipped = [None] * t_scales, [
                None
            ] * t_scales  # temporally subsampled frames
            flow_ref_skipped, conf_ref_skipped = [None] * t_scales, [
                None
            ] * t_scales  # temporally subsampled flows

            # Process the video sequence in chunks of n_frames_load frames.
            for i in range(0, n_frames_total - t_len + 1, n_frames_load):
                # 5D tensor: batchSize, # of frames, # of channels, height, width
                input_A = Variable(
                    data['A'][:, i * input_nc:(i + t_len) * input_nc,
                              ...]).view(-1, t_len, input_nc, height, width)
                input_B = Variable(
                    data['B'][:, i * output_nc:(i + t_len) * output_nc,
                              ...]).view(-1, t_len, output_nc, height, width)
                inst_A = Variable(data['inst'][:, i:i + t_len, ...]).view(
                    -1, t_len, 1, height,
                    width) if len(data['inst'].size()) > 2 else None

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_last)

                if i == 0:
                    fake_B_first = fake_B[
                        0, 0]  # the first generated image in this sequence
                real_B_prev, real_B = real_Bp[:, :
                                              -1], real_Bp[:,
                                                           1:]  # the collection of previous and current real frames

                ####### discriminator
                ### individual frame discriminator
                flow_ref, conf_ref = flowNet(
                    real_B[:, :, :3, ...],
                    real_B_prev[:, :, :3,
                                ...])  # reference flows and confidences
                fake_B_prev = real_B_prev[:, 0:
                                          1] if fake_B_last is None else fake_B_last[
                                              0][:, -1:]
                if fake_B.size()[1] > 1:
                    fake_B_prev = torch.cat(
                        [fake_B_prev, fake_B[:, :-1].detach()], dim=1)

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                loss_dict_T = []
                # get skipped frames for each temporal scale
                if t_scales > 0:
                    real_B_all, real_B_skipped = get_skipped_frames(
                        real_B_all, real_B, t_scales, tD)
                    fake_B_all, fake_B_skipped = get_skipped_frames(
                        fake_B_all, fake_B, t_scales, tD)
                    flow_ref_all, conf_ref_all, flow_ref_skipped, conf_ref_skipped = get_skipped_flows(
                        flowNet, flow_ref_all, conf_ref_all, real_B_skipped,
                        flow_ref, conf_ref, t_scales, tD)

                # run discriminator for each temporal scale
                for s in range(t_scales):
                    if real_B_skipped[s] is not None and real_B_skipped[
                            s].size()[1] == tD:
                        losses = modelD(s + 1, [
                            real_B_skipped[s], fake_B_skipped[s],
                            flow_ref_skipped[s], conf_ref_skipped[s]
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
                loss_G = loss_dict['G_GAN'] + loss_dict[
                    'G_GAN_Feat'] + loss_dict['G_VGG']
                loss_G += loss_dict['G_Warp'] + loss_dict[
                    'F_Flow'] + loss_dict['F_Warp'] + loss_dict['W']
                if opt.add_face_disc:
                    loss_G += loss_dict['G_f_GAN'] + loss_dict['G_f_GAN_Feat']
                    loss_D += (loss_dict['D_f_fake'] +
                               loss_dict['D_f_real']) * 0.5

                # collect temporal losses
                loss_D_T = []
                t_scales_act = min(t_scales, len(loss_dict_T))
                for s in range(t_scales_act):
                    loss_G += loss_dict_T[s]['G_T_GAN'] + loss_dict_T[s][
                        'G_T_GAN_Feat'] + loss_dict_T[s]['G_T_Warp']
                    loss_D_T.append((loss_dict_T[s]['D_T_fake'] +
                                     loss_dict_T[s]['D_T_real']) * 0.5)

                ###################################### Backward Pass #################################
                optimizer_G = modelG.module.optimizer_G
                optimizer_D = modelD.module.optimizer_D
                # update generator weights
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                # update discriminator weights
                # individual frame discriminator
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                # temporal discriminator
                for s in range(t_scales_act):
                    optimizer_D_T = getattr(modelD.module,
                                            'optimizer_D_T' + str(s))
                    optimizer_D_T.zero_grad()
                    loss_D_T[s].backward()
                    optimizer_D_T.step()

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == 0:
                t = (time.time() - iter_start_time) / opt.print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                if opt.label_nc != 0:
                    input_image = util.tensor2label(real_A[0, -1],
                                                    opt.label_nc)
                elif opt.dataset_mode == 'pose':
                    input_image = util.tensor2im(real_A[0, -1, :3],
                                                 normalize=False)
                    if real_A.size()[2] == 6:
                        input_image2 = util.tensor2im(real_A[0, -1, 3:],
                                                      normalize=False)
                        input_image[input_image2 != 0] = input_image2[
                            input_image2 != 0]
                else:
                    c = 3 if opt.input_nc == 3 else 1
                    input_image = util.tensor2im(real_A[0, -1, :c],
                                                 normalize=False)
                if opt.use_instance:
                    edges = util.tensor2im(real_A[0, -1, -1:, ...],
                                           normalize=False)
                    input_image += edges[:, :, np.newaxis]

                if opt.add_face_disc:
                    # Draw a white box around the detected face region.
                    ys, ye, xs, xe = modelD.module.get_face_region(real_A[0,
                                                                          -1:])
                    if ys is not None:
                        input_image[ys, xs:xe, :] = input_image[
                            ye, xs:xe, :] = input_image[
                                ys:ye, xs, :] = input_image[ys:ye, xe, :] = 255

                visual_list = [
                    ('input_image', util.tensor2im(real_A[0, -1])),
                    ('fake_image', util.tensor2im(fake_B[0, -1])),
                    ('fake_first_image', util.tensor2im(fake_B_first)),
                    ('fake_raw_image', util.tensor2im(fake_B_raw[0, -1])),
                    ('real_image', util.tensor2im(real_B[0, -1])),
                    ('flow_ref', util.tensor2flow(flow_ref[0, -1])),
                    ('conf_ref',
                     util.tensor2im(conf_ref[0, -1], normalize=False))
                ]
                if flow is not None:
                    visual_list += [('flow', util.tensor2flow(flow[0, -1])),
                                    ('weight',
                                     util.tensor2im(weight[0, -1],
                                                    normalize=False))]
                visuals = OrderedDict(visual_list)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == 0:
                visualizer.vis_print(
                    'saving the latest model (epoch %d, total_steps %d)' %
                    (epoch, total_steps))
                modelG.module.save('latest')
                modelD.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            visualizer.vis_print(
                'saving the model at the end of epoch %d, iters %d' %
                (epoch, total_steps))
            modelG.module.save('latest')
            modelD.module.save('latest')
            modelG.module.save(epoch)
            modelD.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            modelG.module.update_learning_rate(epoch)
            modelD.module.update_learning_rate(epoch)

        ### gradually grow training sequence length
        if (epoch % opt.niter_step) == 0:
            data_loader.dataset.update_training_batch(epoch // opt.niter_step)
            modelG.module.update_training_batch(epoch // opt.niter_step)

        ### finetune all scales
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                epoch == opt.niter_fix_global):
            modelG.module.update_fixed_params()
Beispiel #17
0
def train_function(params):
    """Run one full GAN training session driven by an argv-style option list.

    Supports mixed supervision: each epoch first iterates an *unpaired*
    dataset (CycleGAN-style updates via ``model.optimize_parameters()``),
    then a *paired* dataset (pix2pix-style updates with a paired-loss
    weight). Losses go to the Visualizer and to TensorBoard. After the last
    epoch the generator ``netG_A`` state dict is saved to
    ``opt.CHECKPOINT_PATH``.

    Args:
        params: argv-style list of command-line arguments; it is assigned
            to ``sys.argv`` before option parsing (NOTE(review): this
            mutates process-global state).

    Example:
        Train a CycleGAN model:
            python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
        Train a pix2pix model:
            python train.py --dataroot ./datasets/maps --name maps_pix2pix --model pix2pix --direction BtoA
        Train a S2OMGAN model:
            python train.py --dataroot ./datasets/maps --name maps_somgan --model somgan
    """

    # Prefer the TensorBoard writer bundled with torch; fall back to the
    # standalone tensorboardX package on older torch installs.
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        from tensorboardX import SummaryWriter

# if __name__ == '__main__':

# region get options from a json file
# get_opt_json()
# endregion

    # Replace the process argv so TrainOptions parses the caller's params.
    sys.argv = params

    opt = TrainOptions().parse()  # get training options
    opt.dataroot = opt.TRAIN_FILE_PATH

    # Paired data is expected under <dataroot>/train; skip if absent.
    opt.dataset_mode = 'aligned'
    if os.path.exists(opt.dataroot + "/train"):
        datasetP = create_dataset(opt)  # create a paired dataset
    else:
        datasetP = []
    datasetP_size = len(datasetP)  # get the number of images in the dataset.
    print('The number of paired training images = %d' % datasetP_size)

    # Unpaired data is expected under <dataroot>/trainA and <dataroot>/trainB.
    opt.dataset_mode = 'unaligned'
    if os.path.exists(opt.dataroot +
                      "/trainA") and os.path.exists(opt.dataroot + "/trainB"):
        datasetU = create_dataset(opt)  # create a unpaired dataset
    else:
        datasetU = []
    datasetU_size = len(datasetU)  # get the number of images in the dataset.
    print('The number of unpaired training images = %d' % datasetU_size)

    model = create_model(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(
        opt)  # create a visualizer that display/save images and plots
    total_iters = 0  # the total number of training iterations

    tb_writer = SummaryWriter(opt.LOG_PATH)

    for epoch in range(
            opt.epoch_count, opt.niter + opt.niter_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch

        # unpaired part - CycleGAN mode.
        for i, data in enumerate(datasetU):  # inner loop within one epoch
            iter_start_time = time.time(
            )  # timer for computation per iteration
            # t_data is only refreshed on print iterations; total_iters starts
            # at 0, so it is always set before its first use below.
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(
                data)  # unpack data from dataset and apply preprocessing
            model.optimize_parameters(
            )  # calculate loss functions, get gradients, update network weights

            #             if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
            #                 save_result = total_iters % opt.update_html_freq == 0
            #                 model.compute_visuals()
            #                 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses,
                                                t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / datasetU_size, losses)

                # Mirror every named loss into TensorBoard.
                for name, val in losses.items():
                    tb_writer.add_scalar("loss_" + name, val, total_iters)

#             if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
#                 print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
#                 save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
#                 model.save_networks(save_suffix)

            iter_data_time = time.time()

        # paired part - pix2pix mode.
        for i, data in enumerate(datasetP):  # inner loop within one epoch
            iter_start_time = time.time(
            )  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(
                data)  # unpack data from dataset and apply preprocessing
            # Paired updates pass a supervision weight and the training
            # progress ratio (epoch fraction of the full schedule).
            model.optimize_parameters(
                lambda_paired_loss=1.,
                epoch_ratio=(1. * epoch / (opt.niter + opt.niter_decay))
            )  # calculate loss functions, get gradients, update network weights

            #             if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
            #                 save_result = total_iters % opt.update_html_freq == 0
            #                 model.compute_visuals()
            #                 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses,
                                                t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / datasetP_size, losses)

                for name, val in losses.items():
                    tb_writer.add_scalar("loss_" + name, val, total_iters)

#             if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
#                 print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
#                 save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
#                 model.save_networks(save_suffix)

            iter_data_time = time.time()

#         if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
#             print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
#             model.save_networks('latest')
#             model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))
        model.update_learning_rate(
        )  # update learning rates at the end of every epoch.


# save net for platform.
    # Always export from CPU so the checkpoint loads on any device; move the
    # net back to its GPU afterwards when one was in use.
    net = getattr(model, 'netG_A')
    if len(model.gpu_ids) > 0 and torch.cuda.is_available():
        torch.save(net.module.cpu().state_dict(), opt.CHECKPOINT_PATH)
        net.cuda(model.gpu_ids[0])
    else:
        torch.save(net.cpu().state_dict(), opt.CHECKPOINT_PATH)
Beispiel #18
0
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.util import confusion_matrix, getScores, tensor2labelim, tensor2im, print_current_losses
import numpy as np
import random
import torch
import cv2
from tensorboardX import SummaryWriter

if __name__ == '__main__':
    # Script entry point: parse training options, seed all RNGs, then build
    # the training and validation data loaders.
    train_opt = TrainOptions().parse()

    # Seed every RNG source so runs are reproducible for a given --seed.
    np.random.seed(train_opt.seed)
    random.seed(train_opt.seed)
    torch.manual_seed(train_opt.seed)
    torch.cuda.manual_seed(train_opt.seed)

    train_data_loader = CreateDataLoader(train_opt)
    train_dataset = train_data_loader.load_data()
    train_dataset_size = len(train_data_loader)
    print('#training images = %d' % train_dataset_size)

    # Validation re-parses the same CLI options, then overrides them for a
    # deterministic single-sample, single-worker evaluation pass.
    valid_opt = TrainOptions().parse()
    valid_opt.phase = 'val'
    valid_opt.batch_size = 1
    valid_opt.num_threads = 1
    valid_opt.serial_batches = True
    valid_opt.isTrain = False
    valid_data_loader = CreateDataLoader(valid_opt)
Beispiel #19
0
def main():
    """Train a video-prediction model on a KTH-style file list.

    Reads the training-video list from ``opt.txtroot/opt.video_list``, then
    for ``nepoch + nepoch_decay`` epochs iterates shuffled mini-batches,
    loading each batch in parallel with joblib and stepping the model once
    per batch. Losses and sample grids are logged to TensorBoard; the model
    is checkpointed as 'latest' every ``opt.save_latest_freq`` steps.

    Fixes vs. the original: the video-list file handle is now closed via a
    context manager (it previously leaked), and an ``inputs_batch`` array
    that was allocated but never used has been removed.
    """
    opt = TrainOptions().parse()
    # Read training files; `with` guarantees the handle is closed even if
    # readlines() raises.
    with open(os.path.join(opt.txtroot, opt.video_list), 'r') as f:
        trainfiles = f.readlines()
    print('video num: %s' % len(trainfiles))

    # Create model and TensorBoard writer.
    model = create_model(opt)
    total_steps = 0
    writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))
    with Parallel(n_jobs=opt.batch_size) as parallel:
        for epoch in range(opt.start_epoch, opt.nepoch + opt.nepoch_decay + 1):
            mini_batches = util.get_minibatches_idx(len(trainfiles),
                                                    opt.batch_size,
                                                    shuffle=True)

            for _, batchidx in mini_batches:
                # Skip ragged final batches so every stack below is full-size.
                if len(batchidx) == opt.batch_size:
                    # Per-sample constants replicated across the batch for the
                    # parallel loader call.
                    Ts = np.repeat(np.array([opt.T]), opt.batch_size, axis=0)
                    Ks = np.repeat(np.array([opt.K]), opt.batch_size, axis=0)
                    paths = np.repeat(opt.data_root, opt.batch_size, axis=0)
                    tfiles = np.array(trainfiles)[batchidx]
                    shapes = np.repeat(np.array([opt.image_size]),
                                       opt.batch_size,
                                       axis=0)
                    # Load the whole batch in parallel (one worker per sample),
                    # then stack into a single input tensor.
                    output = parallel(
                        delayed(util.load_kth_data)(f, p, image_size, k, t)
                        for f, p, image_size, k, t in zip(
                            tfiles, paths, shapes, Ks, Ts))
                    output = torch.stack(output, dim=0)
                    model.set_inputs(output)
                    model.optimize_parameters()
                    total_steps += 1

                    if total_steps % opt.print_freq == 0:
                        print('total_steps % opt.print_freq == 0')
                        errors = model.get_current_errors()

                        # Mirror each named error into TensorBoard, indexed by
                        # batch count rather than raw step count.
                        for key in errors.keys():
                            writer.add_scalar('loss/%s' % (key), errors[key],
                                              total_steps / opt.batch_size)

                        util.print_current_errors(epoch, total_steps, errors,
                                                  opt.checkpoints_dir,
                                                  opt.name)
                    if total_steps % opt.display_freq == 0:
                        print('total_steps % opt.display_freq == 0')
                        visuals = model.get_current_visuals()
                        grid = util.visual_grid(visuals['seq_batch'],
                                                visuals['pred'], opt.K, opt.T)
                        writer.add_image('current_batch', grid,
                                         total_steps / opt.batch_size)
                    if total_steps % opt.save_latest_freq == 0:
                        print(
                            'saving the latest model (epoch %d, total_steps %d)'
                            % (epoch, total_steps))
                        model.save('latest', epoch)

        print("end training")
Beispiel #20
0
def train():
    """Train a pix2pixHD-style GAN with optional fp16 (apex amp) support.

    Resumes from ``<checkpoints_dir>/<name>/iter.txt`` when
    ``--continue_train`` is set, alternates generator/discriminator updates
    every iteration, runs a full validation pass at the end of each epoch,
    and checkpoints 'latest' plus per-epoch snapshots. The learning rate
    decays linearly once ``epoch > opt.niter``.
    """
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        # Bare except: fall back to a fresh run when iter.txt is missing or
        # unparsable. NOTE(review): an explicit `except Exception` would be
        # safer, but is left as-is here.
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except:
            start_epoch, epoch_iter = 1, 0
        # compute resume lr: replay the linear decay the schedule would have
        # applied by start_epoch.
        if start_epoch > opt.niter:
            lrd_unit = opt.lr / opt.niter_decay
            resume_lr = opt.lr - (start_epoch - opt.niter) * lrd_unit
            opt.lr = resume_lr
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    # Align the print frequency with the batch size so the modulo checks
    # below fire on whole-batch boundaries.
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        # Tiny schedule for quick smoke tests.
        opt.display_freq = 2
        opt.print_freq = 2
        opt.niter = 3
        opt.niter_decay = 0
        opt.max_dataset_size = 1
        opt.valSize = 1

    ## Loading data
    # train data
    data_loader = CreateDataLoader(opt, isVal=False)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# training images = %d' % dataset_size)
    # validation data
    data_loader = CreateDataLoader(opt, isVal=True)
    valset = data_loader.load_data()
    print('# validation images = %d' % len(data_loader))

    ## Loading model
    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        # apex amp wraps model and optimizers for mixed precision; only then
        # is the model wrapped in DataParallel (amp requires this order).
        from apex import amp
        model, [optimizer_G, optimizer_D
                ] = amp.initialize(model,
                                   [model.optimizer_G, model.optimizer_D],
                                   opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    # Deltas keep display/print/save cadence consistent across resumes.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()

            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            model = model.train()
            losses, generated, metrics = model(data['A'],
                                               data['B'],
                                               data['geometry'],
                                               infer=False)

            # sum per device losses and metrics (DataParallel returns one
            # value per GPU; ints are placeholders for disabled losses)
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + opt.gan_feat_weight * loss_dict.get(
                'G_GAN_Feat', 0) + opt.vgg_weight * loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights (amp scales the loss under fp16)
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                metrics_ = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in metric_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                visualizer.print_current_metrics(epoch, epoch_iter, metrics_,
                                                 t)
                visualizer.plot_current_metrics(metrics_, total_steps)
                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images
            # The exr conversion `type` selects the tone-mapping per task
            # (1=specular, 2=low, 3=high) — see util.tensor2im_exr.
            if save_fake:
                if opt.task_type == 'specular':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=1)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=1))
                    ])
                elif opt.task_type == 'low':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=2)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=2))
                    ])
                elif opt.task_type == 'high':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=3)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=3))
                    ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ###########################################################################################
        # validation at the end of each epoch
        val_start_time = time.time()
        metrics_val = []
        for _, val_data in enumerate(valset):
            model = model.eval()
            # model.half()
            generated, metrics = model(val_data['A'],
                                       val_data['B'],
                                       val_data['geometry'],
                                       infer=True)
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            metrics_ = {
                k: v.data.item() if not isinstance(v, int) else v
                for k, v in metric_dict.items()
            }
            metrics_val.append(metrics_)
        # Print out losses (averaged over the whole validation set)
        metrics_val = visualizer.mean4dict(metrics_val)
        t = (time.time() - val_start_time) / opt.print_freq
        visualizer.print_current_metrics(epoch,
                                         epoch_iter,
                                         metrics_val,
                                         t,
                                         isVal=True)
        visualizer.plot_current_metrics(metrics_val, total_steps, isVal=True)
        # visualization of the last validation sample
        if opt.task_type == 'specular':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=1)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=1))
            ])
        if opt.task_type == 'low':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=2)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=2))
            ])
        if opt.task_type == 'high':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=3)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=3))
            ])
        visualizer.display_current_results(visuals, epoch, epoch, isVal=True)
        ###########################################################################################

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()
Beispiel #21
0
import time
from options.test_options import TestOptions
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer

# train: build the training data loader from the parsed training options.
opt_train = TrainOptions().parse()
data_loader_train = CreateDataLoader(opt_train)
dataset_train = data_loader_train.load_data()
dataset_size_train = len(data_loader_train)
print('#training images = %d' % dataset_size_train)

# test: separate options, forced to deterministic single-sample evaluation
# (the test code path only supports one thread / batch size 1).
opt_test = TestOptions().parse()
opt_test.nThreads = 1  # test code only supports nThreads = 1
opt_test.batchSize = 1  # test code only supports batchSize = 1
opt_test.serial_batches = False  # no shuffle
opt_test.no_flip = True  # no flip
opt_test.how_many = 100  # number of test samples to evaluate
data_loader_test = CreateDataLoader(opt_test)
dataset_test = data_loader_test.load_data()
dataset_size_test = len(data_loader_test)
print('#test images = %d' % dataset_size_test)

# The model and visualizer are configured from the *training* options.
model = create_model(opt_train)
visualizer = Visualizer(opt_train)
total_steps = 0
for epoch in range(opt_train.epoch_count,
Beispiel #22
0
def main():
    """Train (or evaluate) a coupled-U-net pose estimator on MPII.

    Builds a binarized CU-Net (``create_cu_net`` + ``BinOp``), optionally
    resumes from a checkpoint, then either runs validation only
    (``opt.is_train`` false) or trains for ``opt.nEpochs`` epochs, logging a
    per-epoch summary and saving a checkpoint after every epoch.

    Fixes vs. the original: the two Python-2 ``print`` statements were
    syntax errors under Python 3 (which the surrounding code targets); they
    are now ``print()`` calls.
    """
    opt = TrainOptions().parse()
    train_history = TrainHistory()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + 'log.txt'
    visualizer.log_name = os.path.join(exp_dir, log_name)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    # if opt.dataset == 'mpii':
    num_classes = 16  # MPII has 16 keypoint classes
    # layer_num = 2
    net = create_cu_net(neck_size=4,
                        growth_rate=32,
                        init_chan_num=128,
                        class_num=num_classes,
                        layer_num=opt.layer_num,
                        order=1,
                        loss_num=opt.layer_num)
    # num1 = get_n_params(net)
    # num2 = get_n_trainable_params(net)
    # num3 = get_n_conv_params(net)
    # print('number of params: ', num1)
    # print('number of trainalbe params: ', num2)
    # print('number of conv params: ', num3)
    # torch.save(net.state_dict(), 'test-model-size.pth.tar')
    # exit()
    # device = torch.device("cuda:0")
    # net = net.to(device)
    net = torch.nn.DataParallel(net).cuda()
    # bin_op is consumed by the train/validate helpers via module scope.
    global bin_op
    bin_op = BinOp(net)
    optimizer = torch.optim.RMSprop(net.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)
    """optionally resume from a checkpoint"""
    if opt.resume_prefix != '':
        # if 'pth' in opt.resume_prefix:
        #     trunc_index = opt.resume_prefix.index('pth')
        #     opt.resume_prefix = opt.resume_prefix[0:trunc_index - 1]
        # checkpoint.save_prefix = os.path.join(exp_dir, opt.resume_prefix)
        checkpoint.save_prefix = exp_dir + '/'
        checkpoint.load_prefix = os.path.join(exp_dir, opt.resume_prefix)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
        # Continue with the learning rate stored in the checkpoint.
        opt.lr = optimizer.param_groups[0]['lr']
        resume_log = True
    else:
        checkpoint.save_prefix = exp_dir + '/'
        resume_log = False
    print('save prefix: ', checkpoint.save_prefix)
    # model = {'state_dict': net.state_dict()}
    # save_path = checkpoint.save_prefix + 'test-model-size.pth.tar'
    # torch.save(model, save_path)
    # exit()
    """load data"""
    train_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=True),
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=False),
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)
    """optimizer"""
    # optimizer = torch.optim.SGD( net.parameters(), lr=opt.lr,
    #                             momentum=opt.momentum,
    #                             weight_decay=opt.weight_decay )
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=opt.lr, alpha=0.99,
    #                                 eps=1e-8, momentum=0, weight_decay=0)
    print(type(optimizer))
    # Joint indices used for PCKh evaluation (excludes 6,7,8,9,12,13).
    # idx = range(0, 16)
    # idx = [e for e in idx if e not in (6, 7, 8, 9, 12, 13)]
    idx = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]
    logger = Logger(os.path.join(opt.exp_dir, opt.exp_id,
                                 'training-summary.txt'),
                    title='training-summary',
                    resume=resume_log)
    logger.set_names(
        ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])
    if not opt.is_train:
        # Evaluation-only mode: run validation once, save predictions, exit.
        visualizer.log_path = os.path.join(opt.exp_dir, opt.exp_id,
                                           'val_log.txt')
        val_loss, val_pckh, predictions = validate(
            val_loader, net, train_history.epoch[-1]['epoch'], visualizer, idx,
            joint_flip_index, num_classes)
        checkpoint.save_preds(predictions)
        return
    """training and validation"""
    start_epoch = 0
    if opt.resume_prefix != '':
        start_epoch = train_history.epoch[-1]['epoch'] + 1
    for epoch in range(start_epoch, opt.nEpochs):
        adjust_lr(opt, optimizer, epoch)
        # # train for one epoch
        train_loss, train_pckh = train(train_loader, net, optimizer, epoch,
                                       visualizer, idx, opt)

        # evaluate on validation set
        val_loss, val_pckh, predictions = validate(val_loader, net, epoch,
                                                   visualizer, idx,
                                                   joint_flip_index,
                                                   num_classes)
        # visualizer.display_imgpts(imgs, pred_pts, 4)
        # exit()
        # update training history
        e = OrderedDict([('epoch', epoch)])
        lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
        loss = OrderedDict([('train_loss', train_loss),
                            ('val_loss', val_loss)])
        pckh = OrderedDict([('val_pckh', val_pckh)])
        train_history.update(e, lr, loss, pckh)
        checkpoint.save_checkpoint(net, optimizer, train_history, predictions)
        # visualizer.plot_train_history(train_history)
        logger.append([
            epoch, optimizer.param_groups[0]['lr'], train_loss, val_loss,
            train_pckh, val_pckh
        ])
    logger.close()
Beispiel #23
0
def train():
    """vid2vid-style training loop: generator plus per-frame and temporal discriminators.

    Iterates over video sequences in chunks of ``n_frames_load`` frames,
    runs the generator and both discriminator types, backpropagates each
    loss through its own optimizer, and periodically logs errors,
    visualizes results, and saves checkpoints.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        # make every iteration display/print so problems surface immediately
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training videos = %d' % dataset_size)

    ### initialize models
    models = create_model(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = create_optimizer(
        opt, models)

    ### set parameters
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, dataset_size)
    visualizer = Visualizer(opt)

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = data_loader.dataset.init_data_params(
                data, n_gpus, tG)
            fake_B_last, frames_all = data_loader.dataset.init_data(t_scales)

            # process the sequence in chunks of n_frames_load frames,
            # carrying fake_B_last across chunks for temporal continuity
            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, inst_A = data_loader.dataset.prepare_data(
                    data, i, input_nc, output_nc)

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_last)

                ####### discriminator
                ### individual frame discriminator
                real_B_prev, real_B = real_Bp[:, :
                                              -1], real_Bp[:,
                                                           1:]  # the collection of previous and current real frames
                flow_ref, conf_ref = flowNet(
                    real_B, real_B_prev)  # reference flows and confidences
                #flow_ref, conf_ref = util.remove_dummy_from_tensor([flow_ref, conf_ref])
                fake_B_prev = modelG.module.compute_fake_B_prev(
                    real_B_prev, fake_B_last, fake_B)

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                # None entries (disabled losses) are replaced by 0
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = modelD.module.get_all_skipped_frames(frames_all, \
                        real_B, fake_B, flow_ref, conf_ref, t_scales, tD, n_frames_load, i, flowNet)

                # run discriminator for each temporal scale
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_G, loss_D, loss_D_T, t_scales_act = modelD.module.get_losses(
                    loss_dict, loss_dict_T, t_scales)

                ###################################### Backward Pass #################################
                # update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights (one optimizer per scale)
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])

                if i == 0:
                    fake_B_first = fake_B[
                        0, 0]  # the first generated image in this sequence

            if opt.debug:
                # report GPU memory usage when debugging
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch and update model params
        save_models(opt,
                    epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    modelG,
                    modelD,
                    end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)
Beispiel #24
0
def train():
    """Training loop for a warping-based face video synthesis model.

    Builds the dataset, trainer, generator/discriminator model and flow
    network, then alternates generator and discriminator updates while
    carrying previously generated frames across each ``n_frames_load``
    chunk of a sequence.
    """
    opt = TrainOptions().parse()
    if opt.distributed:
        init_dist()
        # split the global batch evenly across the available GPUs
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)

    ### setup dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    ### setup trainer
    trainer = Trainer(opt, data_loader)

    ### setup models
    model, flowNet = create_model(opt, trainer.start_epoch)
    flow_gt = conf_gt = [None] * 3

    ref_idx_fix = torch.zeros([opt.batchSize])
    for epoch in tqdm(
            range(trainer.start_epoch, opt.niter + opt.niter_decay + 1)):
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
        for idx, data in enumerate(tqdm(dataset), start=trainer.epoch_iter):
            trainer.start_of_iter()

            if not opt.warp_ani:
                # animation-warping inputs are unused in this mode
                data.update({
                    'ani_image': None,
                    'ani_lmark': None,
                    'cropped_images': None,
                    'cropped_lmarks': None
                })

            if not opt.no_flow_gt:
                data_list = [
                    data['tgt_mask_images'], data['cropped_images'],
                    data['warping_ref'], data['ani_image']
                ]
                flow_gt, conf_gt = flowNet(data_list, epoch)
            data_list = [
                data['tgt_label'], data['tgt_image'], data['tgt_template'],
                data['cropped_images'], flow_gt, conf_gt
            ]
            data_ref_list = [data['ref_label'], data['ref_image']]
            data_prev = [None, None, None]
            data_ani = [
                data['warping_ref_lmark'], data['warping_ref'],
                data['ori_warping_refs'], data['ani_lmark'], data['ani_image']
            ]

            ############## Forward Pass ######################
            # accumulators for per-chunk outputs, used for visualization below
            prevs = {"raw_images":[], "synthesized_images":[], \
                    "prev_warp_images":[], "prev_weights":[], \
                    "ani_warp_images":[], "ani_weights":[], \
                    "ref_warp_images":[], "ref_weights":[], \
                    "ref_flows":[], "prev_flows":[], "ani_flows":[], \
                    "ani_syn":[]}
            for t in range(0, n_frames_total, n_frames_load):

                data_list_t = get_data_t(data_list, n_frames_load, t) + data_ref_list + \
                              get_data_t(data_ani, n_frames_load, t) + data_prev

                # generator step
                g_losses, generated, data_prev, ref_idx = model(
                    data_list_t,
                    save_images=trainer.save,
                    mode='generator',
                    ref_idx_fix=ref_idx_fix)
                g_losses = loss_backward(opt, g_losses,
                                         model.module.optimizer_G)

                # discriminator step on the same chunk
                d_losses, _ = model(data_list_t,
                                    mode='discriminator',
                                    ref_idx_fix=ref_idx_fix)
                d_losses = loss_backward(opt, d_losses,
                                         model.module.optimizer_D)

                # store previous
                store_prev(generated, prevs)

            loss_dict = dict(
                zip(model.module.lossCollector.loss_names,
                    g_losses + d_losses))

            output_data_list = [prevs] + [
                data['ref_image']
            ] + data_ani + data_list + [data['tgt_mask_images']]

            # end_of_iter handles logging/saving; a truthy return stops the epoch
            if trainer.end_of_iter(loss_dict, output_data_list, model):
                break

        trainer.end_of_epoch(model)
Beispiel #25
0
# This model trains for 100 epochs instead of the standard 50; and it uses the TensorBoard
# integration. KLD is still set to 10x the default.

import sys
sys.path.append('../lib/SPADE-master/')
from options.train_options import TrainOptions
from models.pix2pix_model import Pix2PixModel
from collections import OrderedDict
import data
from util.iter_counter import IterationCounter
from util.visualizer import Visualizer
from trainers.pix2pix_trainer import Pix2PixTrainer
import os

# Manually-populated SPADE training options (no CLI parsing is performed;
# attributes are assigned directly on the options object).
opt = TrainOptions()
opt.D_steps_per_G = 1  # discriminator updates per generator update
opt.aspect_ratio = 1.0
opt.batchSize = 1
opt.beta1 = 0.0  # Adam beta1/beta2
opt.beta2 = 0.9
opt.cache_filelist_read = False
opt.cache_filelist_write = False
opt.checkpoints_dir = '/spell/checkpoints/'
opt.contain_dontcare_label = False
opt.continue_train = False
opt.crop_size = 512
opt.dataroot = '/spell/bob_ross_segmented/'  # data mount point
opt.dataset_mode = 'custom'
opt.debug = False
opt.display_freq = 100
opt.display_winsize = 512
Beispiel #26
0
def main():
    """Adversarial domain-adaptation training for a two-head segmentation model.

    Alternates between (1) updating the segmentation network with the
    supervised source loss plus an adversarial loss that pushes target
    predictions to look source-like, and an L1 agreement loss between the
    two prediction heads, and (2) updating the domain discriminator on
    detached source/target score maps.  Scalars are logged to TensorBoard
    and snapshots saved every ``save_pred_every`` steps.
    """
    # torch.manual_seed(1234)
    # torch.cuda.manual_seed(1234)
    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(
        sourceloader)

    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)

    start_iter = 0
    if args.restore_from is not None:
        # checkpoint names end in '_<iteration>.pth'; resume from that step
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

    train_writer = tensorboardX.SummaryWriter(
        os.path.join(args.snapshot_dir, "logs", model_name))

    bce_loss = torch.nn.BCEWithLogitsLoss()
    l1_loss = torch.nn.L1Loss()
    cos_loss = torch.nn.CosineSimilarity(dim=0, eps=1e-06)

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()
    # names of the scalars logged to TensorBoard each iteration
    loss = [
        'loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake', 'loss_D_src_real',
        'loss_D_trg_real'
    ]
    _t['iter time'].tic()

    pbar = tqdm(range(start_iter, args.num_steps_stop))
    #for i in range(start_iter, args.num_steps):
    for i in pbar:

        model.adjust_learning_rate(args, optimizer, i)
        model_D.adjust_learning_rate(args, optimizer_D, i)

        optimizer.zero_grad()
        optimizer_D.zero_grad()
        # freeze the discriminator while the segmentation net is updated
        for param in model_D.parameters():
            param.requires_grad = False

        # Python 3: the .next() iterator method was removed; use next()
        src_img, src_lbl, _, _ = next(sourceloader_iter)
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()
        src_seg_score, src_seg_score2 = model(src_img)
        loss_seg_src1 = CrossEntropy2d(src_seg_score, src_lbl)
        loss_seg_src2 = CrossEntropy2d(src_seg_score2, src_lbl)
        loss_seg_src = loss_seg_src1 + loss_seg_src2
        loss_seg_src.backward()

        if args.data_label_folder_target is not None:
            # NOTE(review): this branch leaves trg_seg_score2 unbound even
            # though it is used below — confirm this path is never taken
            # (or fix) before relying on it.
            trg_img, trg_lbl, _, _ = next(targetloader_iter)
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(
                trg_lbl.long()).cuda()
            trg_seg_score = model(trg_img)
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = next(targetloader_iter)
            trg_img = Variable(trg_img).cuda()
            trg_seg_score, trg_seg_score2 = model(trg_img)
            loss_seg_trg = 0

        # fool the discriminator: label target predictions as source (0)
        outD_trg = model_D(F.softmax(trg_seg_score))
        outD_trg2 = model_D(F.softmax(trg_seg_score2))
        loss_D_trg_fake1 = bce_loss(
            outD_trg,
            Variable(torch.FloatTensor(outD_trg.data.size()).fill_(0)).cuda())
        loss_D_trg_fake2 = bce_loss(
            outD_trg2,
            Variable(torch.FloatTensor(outD_trg2.data.size()).fill_(0)).cuda())
        loss_D_trg_fake = loss_D_trg_fake1 + loss_D_trg_fake2

        # agreement loss: the two heads should produce similar target maps
        loss_agree = l1_loss(F.softmax(trg_seg_score),
                             F.softmax(trg_seg_score2))

        loss_trg = args.lambda_adv_target * loss_D_trg_fake + loss_seg_trg + loss_agree
        loss_trg.backward()

        #Weight Discrepancy Loss

        # W5 = None
        # W6 = None
        # if args.model == 'DeepLab2':

        #     for (w5, w6) in zip(model.layer5.parameters(), model.layer6.parameters()):
        #         if W5 is None and W6 is None:
        #             W5 = w5.view(-1)
        #             W6 = w6.view(-1)
        #         else:
        #             W5 = torch.cat((W5, w5.view(-1)), 0)
        #             W6 = torch.cat((W6, w6.view(-1)), 0)

        #ipdb.set_trace()
        #loss_weight = (torch.matmul(W5, W6) / (torch.norm(W5) * torch.norm(W6)) + 1) # +1 is for a positive loss
        # loss_weight = loss_weight  * damping * 2
        #loss_weight = args.weight_div* (cos_loss(W5,W6) +1)
        #loss_weight.backward()
        # weight-discrepancy loss is disabled; keep a zero placeholder so the
        # logging format string below still works
        loss_weight = torch.zeros(1)

        # unfreeze the discriminator for its own update
        for param in model_D.parameters():
            param.requires_grad = True

        # detach so discriminator gradients do not flow into the generator
        src_seg_score, trg_seg_score = src_seg_score.detach(
        ), trg_seg_score.detach()
        src_seg_score2, trg_seg_score2 = src_seg_score2.detach(
        ), trg_seg_score2.detach()

        outD_src = model_D(F.softmax(src_seg_score))
        loss_D_src_real1 = bce_loss(
            outD_src,
            Variable(torch.FloatTensor(
                outD_src.data.size()).fill_(0)).cuda()) / 2
        outD_src2 = model_D(F.softmax(src_seg_score2))
        loss_D_src_real2 = bce_loss(
            outD_src2,
            Variable(torch.FloatTensor(
                outD_src2.data.size()).fill_(0)).cuda()) / 2
        loss_D_src_real = loss_D_src_real1 + loss_D_src_real2
        loss_D_src_real.backward()

        outD_trg = model_D(F.softmax(trg_seg_score))
        loss_D_trg_real1 = bce_loss(
            outD_trg,
            Variable(torch.FloatTensor(
                outD_trg.data.size()).fill_(1)).cuda()) / 2
        outD_trg2 = model_D(F.softmax(trg_seg_score2))
        loss_D_trg_real2 = bce_loss(
            outD_trg2,
            Variable(torch.FloatTensor(
                outD_trg2.data.size()).fill_(1)).cuda()) / 2
        loss_D_trg_real = loss_D_trg_real1 + loss_D_trg_real2
        loss_D_trg_real.backward()

        d_loss = loss_D_src_real.data + loss_D_trg_real.data

        optimizer.step()
        optimizer_D.step()

        # NOTE(review): eval() here only resolves the trusted local names
        # listed in `loss`; do not extend this pattern to external strings.
        for m in loss:
            train_writer.add_scalar(m, eval(m), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '_D.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][adv loss %.4f][d loss %.4f][agree loss %.4f][div loss %.4f][lr %.4f][%.2fs]' %
                  (i + 1, loss_seg_src.data, loss_D_trg_fake.data, d_loss,
                   loss_agree.data, loss_weight.data,
                   optimizer.param_groups[0]['lr'] * 10000,
                   _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
Beispiel #27
0
import torch
import sys
from torch.autograd import Variable
import numpy as np
from options.train_options import TrainOptions
opt = TrainOptions().parse()  # set CUDA_VISIBLE_DEVICES before import torch
from data.data_loader import CreateDataLoader
from models.models import create_model
from skimage import io
from skimage.transform import resize

# demo image processed by test_simple() below
img_path = 'demo.jpg'

model = create_model(opt)

# resolution the demo image is resized to before inference (H x W)
input_height = 384
input_width = 512


def test_simple(model):
    """Run a single-image inference demo on the module-level ``img_path``.

    Switches the model to eval mode, loads the demo image, resizes it to
    (input_height, input_width), and converts it to a CHW float tensor with
    a leading batch dimension.  The visible body ends here; the inference
    step presumably follows in the original source.
    """
    total_loss = 0
    toal_count = 0  # NOTE(review): looks like a typo for 'total_count'
    print("============================= TEST ============================")
    model.switch_to_eval()

    # HWC uint8 image -> float in [0, 1]
    img = np.float32(io.imread(img_path)) / 255.0
    img = resize(img, (input_height, input_width), order=1)
    # HWC -> CHW, then add a batch dimension
    input_img = torch.from_numpy(np.transpose(img,
                                              (2, 0, 1))).contiguous().float()
    input_img = input_img.unsqueeze(0)
Beispiel #28
0
def main():
    """Adversarial domain-adaptation training for a single-head segmentation model.

    Alternates between updating the segmentation network (supervised source
    loss + adversarial loss on target predictions) and the domain
    discriminator on detached score maps.  Scalars are logged to TensorBoard
    and snapshots saved every ``save_pred_every`` steps.
    """

    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(
        sourceloader)

    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)

    start_iter = 0
    if args.restore_from is not None:
        # checkpoint names end in '_<iteration>.pth'; resume from that step
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

    train_writer = tensorboardX.SummaryWriter(
        os.path.join(args.snapshot_dir, "logs", model_name))

    bce_loss = torch.nn.BCEWithLogitsLoss()

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()
    # names of the scalars logged to TensorBoard each iteration
    loss = [
        'loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake', 'loss_D_src_real',
        'loss_D_trg_real'
    ]
    _t['iter time'].tic()
    for i in range(start_iter, args.num_steps):

        model.adjust_learning_rate(args, optimizer, i)
        model_D.adjust_learning_rate(args, optimizer_D, i)

        optimizer.zero_grad()
        optimizer_D.zero_grad()
        # freeze the discriminator while the segmentation net is updated
        for param in model_D.parameters():
            param.requires_grad = False

        # Python 3: the .next() iterator method was removed; use next()
        src_img, src_lbl, _, _ = next(sourceloader_iter)
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()
        src_seg_score = model(src_img, lbl=src_lbl)
        loss_seg_src = model.loss
        loss_seg_src.backward()

        if args.data_label_folder_target is not None:
            # pseudo-labels available for the target domain
            trg_img, trg_lbl, _, _ = next(targetloader_iter)
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(
                trg_lbl.long()).cuda()
            trg_seg_score = model(trg_img, lbl=trg_lbl)
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = next(targetloader_iter)
            trg_img = Variable(trg_img).cuda()
            trg_seg_score = model(trg_img)
            loss_seg_trg = 0

        # fool the discriminator: label target predictions as source (0)
        outD_trg = model_D(F.softmax(trg_seg_score), 0)
        loss_D_trg_fake = model_D.loss

        loss_trg = args.lambda_adv_target * loss_D_trg_fake + loss_seg_trg
        loss_trg.backward()

        # unfreeze the discriminator for its own update
        for param in model_D.parameters():
            param.requires_grad = True

        # detach so discriminator gradients do not flow into the generator
        src_seg_score, trg_seg_score = src_seg_score.detach(
        ), trg_seg_score.detach()

        outD_src = model_D(F.softmax(src_seg_score), 0)
        loss_D_src_real = model_D.loss / 2
        loss_D_src_real.backward()

        outD_trg = model_D(F.softmax(trg_seg_score), 1)
        loss_D_trg_real = model_D.loss / 2
        loss_D_trg_real.backward()

        optimizer.step()
        optimizer_D.step()

        # NOTE(review): eval() here only resolves the trusted local names
        # listed in `loss`; do not extend this pattern to external strings.
        for m in loss:
            train_writer.add_scalar(m, eval(m), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][lr %.4f][%.2fs]' %
                  (i + 1, loss_seg_src.data,
                   optimizer.param_groups[0]['lr'] * 10000,
                   _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
You need to specify the experiment name ('--name') and the game ('--game').

This script trains an AlphaZero-style model for a chosen board game.
It first creates model, dataset, and visualizer given the option.
It then does standard network training. During the training, it also visualize/save the images, print/save the loss plot, and save models.
The script supports continue/resume training. Use '--continue_train' to resume your previous training.
Example:
    Train a model for the game of Go:
        python train.py --name alphazero_go
    or
        python train.py --name alphazero_go --game go --model resnet

See options/base_options.py and options/train_options.py for more training options.
"""

from options.train_options import TrainOptions
from games import create_game
from models import create_model
from mcts import MonteTree
from coach import Coacher

if __name__ == '__main__':
    opt = TrainOptions().parse()   # get training options
    game = create_game(opt)
    model = create_model(opt, game)
    # opponent models used by the coach during self-play;
    # fix: was assigned to 'admodels' but referenced as 'advmodels' (NameError)
    advmodels = [create_model(opt, game) for __ in range(opt.num_opp)]
    montetree = MonteTree(opt, game, model)
    coacher = Coacher(opt, game, model, montetree, advmodels)
    coacher.learn()

Beispiel #30
0
def train_main(raw_args=None):
    """CycleGAN-style training entry point.

    Parses options from ``raw_args``, builds the dataset/model/visualizer,
    optionally resumes from the newest saved epoch checkpoint, then runs the
    standard epoch/iteration loop with periodic display, loss logging, and
    checkpointing.

    Args:
        raw_args: optional argv-style list forwarded to TrainOptions.parse();
            None means use the process's command line.
    """
    # print(torch.backends.cudnn.benchmark)
    opt = TrainOptions().parse(raw_args)  # get training options
    if opt.debug_mode:
        # single-process 'spawn' workers make debugger breakpoints reachable
        import multiprocessing
        multiprocessing.set_start_method('spawn', True)
        opt.num_threads = 0

    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    # resume from the highest-numbered '<epoch>_net_G_A.pth' checkpoint, if any
    existing_epochs = glob.glob(opt.checkpoints_dir + '/' + opt.name +
                                '/*[0-9]_net_G_A.pth')
    if opt.restart_training and len(existing_epochs) > 0:
        opt.epoch = int(
            os.path.splitext(os.path.basename(
                existing_epochs[-1]))[0].split('_')[0])
        opt.epoch_count = opt.epoch + 1

    plot_losses_from_log_files(opt,
                               dataset_size,
                               domain=['A', 'B'],
                               specified=['G', 'D', 'cycle'])

    model = create_model(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(
        opt)  # create a visualizer that display/save images and plots
    total_iters = 0  # the total number of training iterations

    for epoch in range(
            opt.epoch_count, opt.niter + opt.niter_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch

        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time(
            )  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(
                data)  # unpack data from dataset and apply preprocessing
            model.optimize_parameters(
            )  # calculate loss functions, get gradients, update network weights

            if total_iters % opt.display_freq == 0:  # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(),
                                                   epoch, save_result)

            if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses,
                                                t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / dataset_size, losses)

            if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' %
                      (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))
        model.update_learning_rate(
        )  # update learning rates at the end of every epoch.