Example #1
def setup(model_dir):
    global model
    opt = SimpleNamespace()  # requires: from types import SimpleNamespace
    opt.nThreads = 1
    opt.batchSize = 1
    opt.serial_batches = True
    opt.no_flip = True 
    opt.name = 'pretrained'
    opt.checkpoints_dir = '.'
    opt.model = 'pix2pix'
    opt.which_direction = 'AtoB'
    opt.norm = 'batch'
    opt.input_nc = 3
    opt.output_nc = 1
    opt.which_model_netG = 'resnet_9blocks'
    opt.no_dropout = True
    opt.isTrain = False
    opt.use_cuda = True
    opt.ngf = 64
    opt.ndf = 64
    opt.init_type = 'normal'
    opt.which_epoch = 'latest'
    opt.pretrain_path = model_dir
    model = create_model(opt)
    return model
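
Several snippets in this collection build their options object from plain attributes on a types.SimpleNamespace instead of parsing command-line flags. A minimal, self-contained sketch of that pattern (the field names below are illustrative):

from types import SimpleNamespace

defaults = {'batchSize': 1, 'nThreads': 1, 'serial_batches': True}  # illustrative defaults
opt = SimpleNamespace(**defaults)
opt.no_flip = True          # attributes can also be added after construction
print(opt.batchSize)        # -> 1
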
Example #2
opt = TrainOptions().parse()
n_frames_G, n_frames_D = opt.n_frames_G, opt.n_frames_D
t_scales = opt.n_scales_temporal
input_nc = opt.input_nc
output_nc = opt.output_nc

visualizer = Visualizer(opt)

### initialize dataset
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)

### initialize models
modelG, modelD, flowNet = create_model(opt)

iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
if opt.continue_train:
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path,
                                             delimiter=',',
                                             dtype=int)
    except Exception:
        start_epoch, epoch_iter = 1, 0
    if epoch_iter > 0:
        ### initialize dataset again
        if opt.serial_batches:
            data_loader = CreateDataLoader(opt, epoch_iter)
            dataset = data_loader.load_data()
    visualizer.vis_print('Resuming from epoch %d at iteration %d' %
                         (start_epoch, epoch_iter))
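
The resume logic above reads a two-number iter.txt that the training loop writes with np.savetxt (the save side appears in Example #9). A self-contained sketch of that round-trip, with a hypothetical checkpoint directory:

import os
import numpy as np

iter_path = os.path.join('checkpoints', 'demo', 'iter.txt')  # hypothetical location
os.makedirs(os.path.dirname(iter_path), exist_ok=True)
np.savetxt(iter_path, (5, 120), delimiter=',', fmt='%d')     # stopped at epoch 5, iteration 120
start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
print(start_epoch, epoch_iter)                               # -> 5 120
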
Example #3
def train():
    '''load data'''
    transform_mask = transforms.Compose([
        # transforms.Resize((opt.fineSize,opt.fineSize)),
        transforms.ToTensor(),
    ])
    transform = transforms.Compose([
        # transforms.RandomHorizontalFlip(),
        # transforms.Resize((opt.fineSize,opt.fineSize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
    ])

    dataset_train = Data_load(opt.dataroot, opt.maskroot, transform,
                              transform_mask)
    iterator_train = (data.DataLoader(dataset_train,
                                      batch_size=opt.batchSize,
                                      shuffle=True))

    dataset_test = Data_load(opt.valinputroot, opt.valmask, transform,
                             transform_mask)
    iterator_test = (data.DataLoader(dataset_test,
                                     batch_size=opt.batchSize,
                                     shuffle=False))

    print(len(dataset_train))
    print(len(dataset_test))
    model = create_model(opt)
    total_steps = 0

    if not os.path.exists(opt.checkpoints_dir):
        os.makedirs(opt.checkpoints_dir)
    '''if resuming'''
    # model.load(epoch_number/best)

    best_score = 0
    gts = []
    preds = []
    for i in range(100):
        img_name = "{}.jpg".format(401 + i)
        gts.append(os.path.join(opt.valroot, img_name))
        pred_name = "{}.jpg".format(401 + i)
        preds.append(os.path.join("pred", pred_name))

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        # for epoch in range(loaded_epoch_#+1, opt.niter + opt.niter_decay + 1):
        print("this is a new epoch!")
        epoch_start_time = time.time()
        epoch_iter = 0

        for ori_image, ori_mask in iterator_train:
            align_corners = True
            img_size = 512  # Or any other size

            # Shrinking the image to a square. # ori_imgs is the original image
            image = F.interpolate(ori_image,
                                  img_size,
                                  mode='bicubic',
                                  align_corners=align_corners)
            mask = F.interpolate(ori_mask,
                                 img_size,
                                 mode='bicubic',
                                 align_corners=align_corners
                                 )  # mask looks pretty complete at this stage

            image = image.clamp(min=-1, max=1)
            mask = (mask > 0).type(torch.FloatTensor)

            image = image.cuda()
            mask = mask.cuda()
            mask = mask[0][0]
            mask = torch.unsqueeze(mask, 0)
            mask = torch.unsqueeze(mask, 1)
            mask = mask.byte()

            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(
                image, mask)  # sets both input data with mask and latent mask.
            model.set_gt_latent()
            model.optimize_parameters()

            if total_steps % opt.display_freq == 0:
                real_A, real_B, fake_B = model.get_current_visuals()
                # real_A = input, real_B = ground truth, fake_B = output
                pic = (torch.cat([real_A, real_B, fake_B], dim=0) + 1) / 2.0
                print("saving image!")
                torchvision.utils.save_image(
                    pic,
                    '%s/Epoch_(%d)_(%dof%d).jpg' %
                    (opt.checkpoints_dir, epoch, total_steps + 1,
                     len(dataset_train)),
                    nrow=2)

                if total_steps % 1 == 0:  # always true: errors print every time images are saved
                    errors = model.get_current_errors()
                    print(errors)

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save(epoch)

        # evaluation #
        generate_images(model, iterator_test)
        mse_avg, ssim_avg = get_average_mse_ssim(gts, preds)
        score = 1 - mse_avg / 100 + ssim_avg

        if score > best_score:
            best_score = score
            model.save("best")
            print(
                f"best_score {best_score} found in epoch {epoch}. MSE:{mse_avg} SSIM:{ssim_avg}"
            )

        ### end evaluation ###

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        model.update_learning_rate()
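
The training loop above resizes the mask with bicubic interpolation and then re-binarizes it, since bicubic resampling produces intermediate (and slightly out-of-range) values. A minimal sketch of that step, assuming a 1x1xHxW float mask:

import torch
import torch.nn.functional as F

ori_mask = (torch.rand(1, 1, 300, 200) > 0.5).float()        # toy binary mask
mask = F.interpolate(ori_mask, 512, mode='bicubic', align_corners=True)
mask = (mask > 0).type(torch.FloatTensor)                    # same threshold as above
print(mask.shape, mask.unique())                             # torch.Size([1, 1, 512, 512]), values in {0., 1.}
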
Example #4
def test(opt):
    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#test batches = %d' % (int(dataset_size / len(opt.sort_order))))
    visualizer = Visualizer(opt)
    model = create_model(opt)
    model.eval()

    # create webpage
    if opt.random_seed != -1:
        exp_dir = '%s_%s_seed%s' % (opt.phase, opt.which_epoch, str(opt.random_seed))
    else:
        exp_dir = '%s_%s' % (opt.phase, opt.which_epoch)
    web_dir = os.path.join(opt.results_dir, opt.name, exp_dir)

    if opt.traverse or opt.deploy:
        if opt.traverse:
            out_dirname = 'traversal'
        else:
            out_dirname = 'deploy'
        output_dir = os.path.join(web_dir,out_dirname)
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        for image_path in opt.image_path_list:
            print(image_path)
            data = dataset.dataset.get_item_from_path(image_path)
            visuals = model.inference(data)
            if opt.traverse and opt.make_video:
                out_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + '.mp4')
                visualizer.make_video(visuals, out_path)
            elif opt.traverse or (opt.deploy and opt.full_progression):
                if opt.traverse and opt.compare_to_trained_outputs:
                    out_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + '_compare_to_{}_jump_{}.png'.format(opt.compare_to_trained_class, opt.trained_class_jump))
                else:
                    out_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + '.png')
                visualizer.save_row_image(visuals, out_path, traverse=opt.traverse)
            else:
                out_path = os.path.join(output_dir, os.path.basename(image_path[:-4]))
                visualizer.save_images_deploy(visuals, out_path)
    else:
        webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

        # test
        for i, data in enumerate(dataset):
            if i >= opt.how_many:
                break

            visuals = model.inference(data)
            img_path = data['Paths']
            rem_ind = []
            for j, path in enumerate(img_path):
                if path != '':
                    print('process image... %s' % path)
                else:
                    rem_ind += [j]

            for ind in reversed(rem_ind):
                del img_path[ind]

            visualizer.save_images(webpage, visuals, img_path)

            webpage.save()
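
Example #4 removes the empty-path entries by deleting in reverse index order, so that each deletion leaves the remaining (smaller) indices valid. A tiny illustration:

paths = ['a.png', '', 'b.png', '']
rem_ind = [j for j, p in enumerate(paths) if p == '']
for ind in reversed(rem_ind):    # descending order: deletions don't shift pending indices
    del paths[ind]
print(paths)                     # -> ['a.png', 'b.png']
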
Example #5
train_list_dir_portrait = root + '/phoenix/S6/zl548/MegaDpeth_code/train_list/portrait/'
input_height = 320
input_width = 240
data_loader_p = CreateDataLoader(root, train_list_dir_portrait, input_height,
                                 input_width)
dataset_p = data_loader_p.load_data()
dataset_size_p = len(data_loader_p)
print('========================= training portrait  images = %d' %
      dataset_size_p)

_isTrain = False
batch_size = 32
num_iterations_L = dataset_size_l // batch_size  # dataset_size_l: presumably set by a landscape loader above (not shown)
num_iterations_P = dataset_size_p // batch_size
model = create_model(opt, _isTrain)
model.switch_to_train()

best_loss = 100

print("num_iterations ", num_iterations_L, num_iterations_P)

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
best_epoch = 0
total_iteration = 0

print(
    "=================================  BEGIN TRAINING ====================================="
)
Example #6
    start_epoch, epoch_iter = 1, 0

if opt.debug:
    opt.display_freq = 1
    opt.print_freq = 1
    opt.niter = 1
    opt.niter_decay = 0
    opt.max_dataset_size = 10

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)
num_batches = int(dataset_size / opt.batchSize)

model = create_model(opt, num_batches)
visualizer = Visualizer(opt)

total_steps = (start_epoch - 1) * dataset_size + epoch_iter

for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    if epoch != start_epoch:
        epoch_iter = epoch_iter % num_batches
    iter_start = epoch_iter
    for i, data in enumerate(dataset, start=epoch_iter):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += 1  #opt.batchSize
        ##################
        ## Forward Pass ##
Example #7
elif opt.dataset == 'scannet':
    eval_list_path = root + '/phoenix/S3/zl548/ScanNet/test_scannet_normal_list.txt'
    eval_num_threads = 2
    test_data_loader = CreateScanNetDataLoader(opt, eval_list_path, False,
                                               EVAL_BATCH_SIZE,
                                               eval_num_threads)
    test_dataset = test_data_loader.load_data()
    test_data_size = len(test_data_loader)
    print('========================= ScanNet eval #images = %d =========' %
          test_data_size)

else:
    print('INPUT DATASET DOES NOT EXIST!!!')
    sys.exit()

model = create_model(opt, _isTrain=False)
model.switch_to_train()

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
global_step = 0


def test_numerical(model, dataset, global_step):
    rot_e_list = []
    roll_e_list = []
    pitch_e_list = []

    count = 0.0

    model.switch_to_eval()
Example #8
def train():
    opt = TrainOptions().parse()    
    if opt.distributed:
        init_dist()
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)    

    ### setup dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    ### setup trainer    
    trainer = Trainer(opt, data_loader) 

    ### setup models
    model, flowNet = create_model(opt, trainer.start_epoch)
    flow_gt = conf_gt = [None] * 3      
    
    ref_idx_fix = torch.zeros([opt.batchSize])
    for epoch in tqdm(range(trainer.start_epoch, opt.niter + opt.niter_decay + 1)):
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
        for idx, data in enumerate(tqdm(dataset), start=trainer.epoch_iter):
            trainer.start_of_iter()            

            if not opt.warp_ani:
                data.update({'ani_image':None, 'ani_lmark':None, 'cropped_images':None, 'cropped_lmarks':None })

            if not opt.no_flow_gt: 
                data_list = [data['tgt_mask_images'], data['cropped_images'], data['warping_ref'], data['ani_image']]
                flow_gt, conf_gt = flowNet(data_list, epoch)
            data_list = [data['tgt_label'], data['tgt_image'], data['tgt_template'], data['cropped_images'], flow_gt, conf_gt]
            data_ref_list = [data['ref_label'], data['ref_image']]
            data_prev = [None, None, None]
            data_ani = [data['warping_ref_lmark'], data['warping_ref'], data['ori_warping_refs'], data['ani_lmark'], data['ani_image']]

            ############## Forward Pass ######################
            prevs = {"raw_images":[], "synthesized_images":[], \
                    "prev_warp_images":[], "prev_weights":[], \
                    "ani_warp_images":[], "ani_weights":[], \
                    "ref_warp_images":[], "ref_weights":[], \
                    "ref_flows":[], "prev_flows":[], "ani_flows":[], \
                    "ani_syn":[]}
            for t in range(0, n_frames_total, n_frames_load):
                
                data_list_t = get_data_t(data_list, n_frames_load, t) + data_ref_list + \
                              get_data_t(data_ani, n_frames_load, t) + data_prev

                g_losses, generated, data_prev, ref_idx = model(data_list_t, save_images=trainer.save, mode='generator', ref_idx_fix=ref_idx_fix)
                g_losses = loss_backward(opt, g_losses, model.module.optimizer_G)

                d_losses, _ = model(data_list_t, mode='discriminator', ref_idx_fix=ref_idx_fix)
                d_losses = loss_backward(opt, d_losses, model.module.optimizer_D)

                # store previous
                store_prev(generated, prevs)
                        
            loss_dict = dict(zip(model.module.lossCollector.loss_names, g_losses + d_losses))     

            output_data_list = [prevs] + [data['ref_image']] + data_ani + data_list + [data['tgt_mask_images']]

            if trainer.end_of_iter(loss_dict, output_data_list, model):
                break        

            # pdb.set_trace()

        trainer.end_of_epoch(model)
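
get_data_t is not shown in this snippet; judging by its use above (slicing n_frames_load frames starting at t out of each sequence), a plausible stand-in looks like this (an assumption, not the project's actual helper):

def get_data_t(data_list, n_frames_load, t):
    # Slice each BxTx... sequence tensor to frames [t, t + n_frames_load); pass None through.
    return [d[:, t:t + n_frames_load] if d is not None else None for d in data_list]
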
Example #9
def train(opt):
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    if opt.continue_train:
        if opt.which_epoch == 'latest':
            try:
                start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                     delimiter=',',
                                                     dtype=int)
            except Exception:
                start_epoch, epoch_iter = 1, 0
        else:
            start_epoch, epoch_iter = int(opt.which_epoch), 0

        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        for update_point in opt.decay_epochs:
            if start_epoch < update_point:
                break

            opt.lr *= opt.decay_gamma
    else:
        start_epoch, epoch_iter = 0, 0

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    total_steps = (start_epoch) * dataset_size + epoch_iter

    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    bSize = opt.batchSize

    # if there's no display, sample one image from each class to test after every epoch
    if opt.display_id == 0:
        dataset.dataset.set_sample_mode(True)
        dataset.num_workers = 1
        for i, data in enumerate(dataset):
            if i * opt.batchSize >= opt.numClasses:
                break
            if i == 0:
                sample_data = data
            else:
                for key, value in data.items():
                    if torch.is_tensor(data[key]):
                        sample_data[key] = torch.cat(
                            (sample_data[key], data[key]), 0)
                    else:
                        sample_data[key] = sample_data[key] + data[key]
        dataset.num_workers = opt.nThreads
        dataset.dataset.set_sample_mode(False)

    for epoch in range(start_epoch, opt.epochs):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = 0
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = (total_steps % opt.display_freq
                         == display_delta) and (opt.display_id > 0)

            ############## Network Pass ########################
            model.set_inputs(data)
            disc_losses = model.update_D()
            gen_losses, gen_in, gen_out, rec_out, cyc_out = model.update_G(
                infer=save_fake)
            loss_dict = dict(gen_losses, **disc_losses)
            ##################################################

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.item()
                    if not (isinstance(v, float) or isinstance(v, int)) else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch + 1, epoch_iter, errors,
                                                t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, errors)

            ### display output images
            if save_fake and opt.display_id > 0:
                class_a_suffix = ' class {}'.format(data['A_class'][0])
                class_b_suffix = ' class {}'.format(data['B_class'][0])
                classes = None

                visuals = OrderedDict()
                visuals_A = OrderedDict([('real image' + class_a_suffix,
                                          util.tensor2im(gen_in.data[0]))])
                visuals_B = OrderedDict([('real image' + class_b_suffix,
                                          util.tensor2im(gen_in.data[bSize]))])

                A_out_vis = OrderedDict([('synthesized image' + class_b_suffix,
                                          util.tensor2im(gen_out.data[0]))])
                B_out_vis = OrderedDict([('synthesized image' + class_a_suffix,
                                          util.tensor2im(gen_out.data[bSize]))
                                         ])
                if opt.lambda_rec > 0:
                    A_out_vis.update([('reconstructed image' + class_a_suffix,
                                       util.tensor2im(rec_out.data[0]))])
                    B_out_vis.update([('reconstructed image' + class_b_suffix,
                                       util.tensor2im(rec_out.data[bSize]))])
                if opt.lambda_cyc > 0:
                    A_out_vis.update([('cycled image' + class_a_suffix,
                                       util.tensor2im(cyc_out.data[0]))])
                    B_out_vis.update([('cycled image' + class_b_suffix,
                                       util.tensor2im(cyc_out.data[bSize]))])

                visuals_A.update(A_out_vis)
                visuals_B.update(B_out_vis)
                visuals.update(visuals_A)
                visuals.update(visuals_B)

                ncols = len(visuals_A)
                visualizer.display_current_results(visuals, epoch, classes,
                                                   ncols)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch + 1, total_steps))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')
                if opt.display_id == 0:
                    model.eval()
                    visuals = model.inference(sample_data)
                    visualizer.save_matrix_image(visuals, 'latest')
                    model.train()

        # end of epoch
        iter_end_time = time.time()
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch + 1, opt.epochs, time.time() - epoch_start_time))

        ### save model for this epoch
        if (epoch + 1) % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch + 1, total_steps))
            model.save('latest')
            model.save(epoch + 1)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')
            if opt.display_id == 0:
                model.eval()
                visuals = model.inference(sample_data)
                visualizer.save_matrix_image(visuals, epoch + 1)
                model.train()

        ### multiply learning rate by opt.decay_gamma after certain iterations
        if (epoch + 1) in opt.decay_epochs:
            model.update_learning_rate()
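
The *_delta offsets computed before the loop keep periodic actions on the same cadence after a resume: an action fires when total_steps % freq equals the remainder the run started with, rather than at multiples of freq counted from the restart. A tiny illustration:

display_freq = 100
total_steps = 2340                               # e.g. resumed mid-training
display_delta = total_steps % display_freq       # -> 40
for step in range(total_steps + 1, total_steps + 301):
    if step % display_freq == display_delta:
        print('display at step', step)           # 2440, 2540, 2640
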
Example #10
import os  # needed below for os.path.join
from data.data_loader import CreateDataLoader
from models.models import create_model
import utils.util as util
from utils.visualizer import Visualizer
from utils import html
from torch.autograd import Variable
import time

opt = TestOptions().parse(save=False)
opt.nThreads = 1
opt.batchSize = 1
opt.serial_batches = True  # no shuffle

data_loader = CreateDataLoader(opt)
dataset, _ = data_loader.load_data()
model = create_model(opt, data_loader.dataset)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name,
                       '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, '%s: %s' % (opt.name, opt.which_epoch))

img_dir = os.path.join(web_dir, 'images')
# test

label_trues, label_preds = [], []

model.model.eval()
tic = time.time()

accs = []
Example #11
    D_mesh = D_mesh.unsqueeze_(0).expand(batchSize, dpt, hgt, wdt)
    H_mesh = H_mesh.unsqueeze_(0).expand(batchSize, dpt, hgt, wdt)
    W_mesh = W_mesh.unsqueeze_(0).expand(batchSize, dpt, hgt, wdt)
    D_upmesh = dDepth.float() + D_mesh
    H_upmesh = dHeight.float() + H_mesh
    W_upmesh = dWidth.float() + W_mesh
    return torch.stack([D_upmesh, H_upmesh, W_upmesh], dim=1)


if __name__ == '__main__':
    opt = TestOptions().parse()

    opt.nThreads = 1
    opt.batchSize = 1
    model_regist = create_model(opt)
    visualizer = Visualizer(opt)
    stn = Dense3DSpatialTransformer()

    # create website
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    datafiles = []
    dataFiles = sorted(os.listdir(opt.dataroot))
    for dataName in dataFiles:
        datafiles.append(os.path.join(opt.dataroot, dataName))
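
The D_mesh/H_mesh/W_mesh tensors that the displacement fields are added to at the top of this example are base coordinate grids. A minimal sketch of how such grids can be built (the volume size is illustrative; indexing='ij' needs PyTorch >= 1.10):

import torch

dpt, hgt, wdt = 4, 6, 8                          # illustrative volume size
D_mesh, H_mesh, W_mesh = torch.meshgrid(
    torch.arange(dpt, dtype=torch.float32),
    torch.arange(hgt, dtype=torch.float32),
    torch.arange(wdt, dtype=torch.float32),
    indexing='ij')
print(D_mesh.shape)   # torch.Size([4, 6, 8]); unsqueeze_(0).expand(...) then adds the batch dim
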
Example #12
def train(config, writer, logger):
    config = config.opt

    config.distributed = False

    if 'WORLD_SIZE' in os.environ:
        config.distributed = int(os.environ['WORLD_SIZE']) > 1

    config.gpu = 0
    config.world_size = 1

    if config.distributed:
        config.gpu = config.local_rank
        torch.cuda.set_device(config.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        config.world_size = torch.distributed.get_world_size()

    data_set = CreateDataLoader(config).load_data()
    total_steps = config.epochs * len(data_set)
    if config.gpu == 0:
        print(len(data_set), "# of Training Images")

    model = create_model(config)
    visualizer = Visualizer(config)

    model.named_buffers = lambda: []  # stub out named_buffers (presumably a workaround for the DDP wrapper below)

    average_tensor = utils.load_average_img(config)
    average_tensor = average_tensor.view(1, *average_tensor.shape).cuda()

    if config.fp16:
        from apex import amp
        from apex.parallel import DistributedDataParallel as DDP
        model = model.cuda()
        model, [optimizer_G, optimizer_D] = \
            amp.initialize(model, [model.optimizer_G, model.optimizer_D],
                           opt_level='O1')

        if config.distributed:
            model = DDP(model, delay_allreduce=True)
        else:
            model = torch.nn.DataParallel(model)

    step = 0

    for epoch in range(config.epochs):
        if config.gpu == 0:
            print("epoch: ", epoch)
        for i, data in enumerate(data_set):
            save_gen = (step + 1) % config.display_freq == 0

            if i == 0 or (not config.no_temporal_smoothing and \
                          np.random.random() < config.prob_restart_sequence):
                if config.no_temporal_smoothing:
                    prev_generated = prev_label = prev_real = None
                else:
                    prev_generated = average_tensor
                    prev_label = torch.zeros_like(data['label']).cuda()
                    prev_real = average_tensor

            data['label'] = data['label'][:, :1]
            assert data['label'].shape[1] == 1

            losses, generated = model(Variable(data['label']).cuda(),
                                      Variable(data['inst'].cuda()),
                                      Variable(data['image'].cuda()),
                                      Variable(data['feat'].cuda()),
                                      prev_label,
                                      prev_generated,
                                      prev_real,
                                      infer=save_gen or \
                                      (not config.no_temporal_smoothing))
            # average=average_tensor)

            # sum per device losses
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]

            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + \
                loss_dict.get('G_VGG', 0)

            # update generator weights
            model.module.optimizer_G.zero_grad()

            if config.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()

            model.module.optimizer_G.step()

            # update discriminator weights
            model.module.optimizer_D.zero_grad()
            if config.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            model.module.optimizer_D.step()

            if (not config.distributed) or torch.distributed.get_rank() == 0:
                if (step +
                        1) % config.print_freq == 0 or step == total_steps - 1:
                    logger.info("Train: [{:2d}/{}] Step {:03d}/{:03d}".format(
                        epoch + 1, config.epochs, i,
                        len(data_set) - 1))
                    logger.info("Loss D: {},  Loss G {}, Loss VGG {}".format(
                        loss_D.item(), loss_dict['G_GAN'].item(),
                        loss_dict['G_VGG'].item()))

                if save_gen:
                    visuals = OrderedDict([
                        ('input_label',
                         util.tensor2label(data['label'][0], config.label_nc)),
                        ('synthesized_image',
                         util.tensor2im(generated.data[0])),
                        ('real_image', util.tensor2im(data['image'][0]))
                    ])

                    if not config.no_temporal_smoothing:
                        visuals['prev_label'] = util.tensor2label(
                            prev_label[0], config.label_nc)
                        visuals['prev_real'] = util.tensor2im(prev_real[0])
                        visuals['prev_generated'] = util.tensor2im(
                            prev_generated.data[0])

                    visualizer.display_current_results(visuals, epoch,
                                                       total_steps)

                if (step + 1) % config.save_latest_freq == 0 or \
                        step == total_steps - 1:
                    model.module.save('latest')
                    model.module.save(epoch)

            if not config.no_temporal_smoothing:
                prev_generated = generated.cuda()
                prev_real = data['image'].detach().cuda()
                prev_label = data['label'].detach().cuda()

            step += 1

        ### train the entire network after certain iterations
        if (config.niter_fix_global != 0) and (epoch
                                               == config.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > config.niter:
            model.module.update_learning_rate()
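
Example #12 scales losses through NVIDIA apex at opt level O1. For reference, the equivalent pattern in PyTorch's built-in torch.cuda.amp looks roughly like this (a sketch on a toy model, not the snippet's actual setup; requires a CUDA device):

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device='cuda')
with torch.cuda.amp.autocast():                  # run the forward pass in mixed precision
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()                    # scaled backward, like amp.scale_loss(...)
scaler.step(optimizer)
scaler.update()
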
Example #13
def infer(n, image_label_path, image_inst_path):
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.name = "label2city_1024p"
    opt.netG = "local"
    opt.ngf = 32
    opt.resize_or_crop = "none"

    data_loader = CreateOneDataLoader(opt)
    dataset = data_loader.load_data(image_label_path, image_inst_path)
    visualizer = Visualizer(opt)
    # create website
    #web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

    # test
    if not opt.engine and not opt.onnx:
        model = create_model(opt)
        if opt.data_type == 16:
            model.half()
        elif opt.data_type == 8:
            model.type(torch.uint8)

        if opt.verbose:
            print(model)
    else:
        from run_engine import run_trt_engine, run_onnx

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            data['label'] = data['label'].byte()  # uint8; torch tensors have no .uint8() method
            data['inst'] = data['inst'].byte()
        if opt.export_onnx:
            print("Exporting to ONNX: ", opt.export_onnx)
            assert opt.export_onnx.endswith(
                "onnx"), "Export model file should end with .onnx"
            torch.onnx.export(model, [data['label'], data['inst']],
                              opt.export_onnx,
                              verbose=True)
            exit(0)
        minibatch = 1
        if opt.engine:
            generated = run_trt_engine(opt.engine, minibatch,
                                       [data['label'], data['inst']])
        elif opt.onnx:
            generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                                 [data['label'], data['inst']])
        else:
            generated = model.inference(data['label'], data['inst'])

        visuals = OrderedDict([
            ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
            ('synthesized_image', util.tensor2im(generated.data[0]))
        ])
        img_path = data['path']
        print('process image... %s' % img_path)
        visualizer.save_image(visuals, n)
Example #14
def getEnlighten(input_image, cFlag):

    opt = Namespace(D_P_times2=False,
                    IN_vgg=False,
                    aspect_ratio=1.0,
                    batchSize=1,
                    checkpoints_dir=baseLoc + 'weights/',
                    dataroot='test_dataset',
                    dataset_mode='unaligned',
                    display_id=1,
                    display_port=8097,
                    display_single_pane_ncols=0,
                    display_winsize=256,
                    fcn=0,
                    fineSize=256,
                    gpu_ids=[0],
                    high_times=400,
                    how_many=50,
                    hybrid_loss=False,
                    identity=0.0,
                    input_linear=False,
                    input_nc=3,
                    instance_norm=0.0,
                    isTrain=False,
                    l1=10.0,
                    lambda_A=10.0,
                    lambda_B=10.0,
                    latent_norm=False,
                    latent_threshold=False,
                    lighten=False,
                    linear=False,
                    linear_add=False,
                    loadSize=286,
                    low_times=200,
                    max_dataset_size='inf',
                    model='single',
                    multiply=False,
                    nThreads=1,
                    n_layers_D=3,
                    n_layers_patchD=3,
                    name='enlightening',
                    ndf=64,
                    new_lr=False,
                    ngf=64,
                    no_dropout=True,
                    no_flip=True,
                    no_vgg_instance=False,
                    noise=0,
                    norm='instance',
                    norm_attention=False,
                    ntest='inf',
                    output_nc=3,
                    patchD=False,
                    patchD_3=0,
                    patchSize=64,
                    patch_vgg=False,
                    phase='test',
                    resize_or_crop='no',
                    results_dir='./results/',
                    self_attention=True,
                    serial_batches=True,
                    skip=1.0,
                    syn_norm=False,
                    tanh=False,
                    times_residual=True,
                    use_avgpool=0,
                    use_mse=False,
                    use_norm=1.0,
                    use_ragan=False,
                    use_wgan=0.0,
                    vary=1,
                    vgg=0,
                    vgg_choose='relu5_3',
                    vgg_maxpooling=False,
                    vgg_mean=False,
                    which_direction='AtoB',
                    which_epoch='200',
                    which_model_netD='basic',
                    which_model_netG='sid_unet_resize',
                    cFlag=cFlag)

    im = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    transform = get_transform(opt)
    A_img = transform(im)
    r, g, b = A_img[0] + 1, A_img[1] + 1, A_img[2] + 1
    A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
    A_gray = torch.unsqueeze(A_gray, 0)
    data = {
        'A': A_img.unsqueeze(0),
        'B': A_img.unsqueeze(0),
        'A_gray': A_gray.unsqueeze(0),
        'input_img': A_img.unsqueeze(0),
        'A_paths': 'A_path',
        'B_paths': 'B_path'
    }

    model = create_model(opt)
    model.set_input(data)
    visuals = model.predict()
    out = visuals['fake_B'].astype(np.uint8)
    out = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    # cv2.imwrite("/Users/kritiksoman/PycharmProjects/new/out.png", out)
    return out
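
The A_gray channel above is an inverted Rec.601 luma computed on the [-1, 1]-normalized image: shifting each channel by +1 puts it in [0, 2], so the weighted sum divided by 2 lies in [0, 1], and 1 - luma is large in dark regions. A quick check of that range:

import torch

A_img = torch.rand(3, 4, 4) * 2 - 1              # normalized image in [-1, 1]
r, g, b = A_img[0] + 1, A_img[1] + 1, A_img[2] + 1
A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
assert A_gray.min() >= 0.0 and A_gray.max() <= 1.0
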
Example #15
def test(opt):
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.no_instance = True

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join(opt.results_dir,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    model = create_model(opt)

    ssim_sum = 0
    psnr_sum = 0
    mse_sum = 0
    scores = np.zeros((len(dataset), 3))
    indices = np.array(choose_k_images(len(dataset), 100))
    saved_scores = np.zeros((len(indices), 3))

    for i, data in enumerate(dataset):
        generated = model.inference(data['label'], data['inst'], data['image'])
        visuals = OrderedDict([
            ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
            ('synthesized_image', util.tensor2im(generated.data[0])),
            ('real_image', util.tensor2im(data['image'][0]))
        ])
        img_path = data['path']
        scores = save_scores(scores, data, generated, i)

        if i in indices:
            # save to their visualizer
            print('process image... %s' % img_path)
            visualizer.save_images(webpage, visuals, img_path)
            #save to my files
            save_img(visuals, opt, img_path[0])
            saved_scores[np.where(indices == i)[0]] = scores[i]
            print(i)

    webpage.save()

    dic_saved = {"viz matrix": saved_scores, "mean": saved_scores.mean(axis=0)}
    with open(opt.viz_dir + "/saved_results.pkl", 'wb') as f:
        pickle.dump(dic_saved, f)

    mean_scores = scores.mean(axis=0)
    dic = {"scores matrix": scores, "mean": mean_scores}
    with open(os.path.join(opt.results_dir, "results.pkl"), 'wb') as f:
        pickle.dump(dic, f)

    num_gpus = torch.cuda.device_count()
    for gpu_id in range(num_gpus):
        torch.cuda.set_device(gpu_id)
        torch.cuda.empty_cache()
    torch.cuda.set_device(0)
    return mean_scores
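
choose_k_images is not defined in this snippet; from its use (picking 100 of the dataset indices to save), a plausible stand-in is (hypothetical, not the author's helper):

import random

def choose_k_images(n, k):
    # hypothetical: k distinct indices sampled from range(n), in ascending order
    return sorted(random.sample(range(n), min(k, n)))
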
Example #16
def setup(opts):
    opt.name = opts['checkpoints_root'].split('/')[-1]
    opt.checkpoints_dir = os.path.join(opts['checkpoints_root'], '..')
    model = create_model(opt)
    return model
Example #17
def main():
    # Create train dataset
    train_set_opt = opt.datasets[0]
    train_set = create_dataset(train_set_opt)
    train_size = int(math.ceil(len(train_set) / train_set_opt.batch_size))
    print('Number of train images: %d batches of size %d' % (train_size, train_set_opt.batch_size))
    total_iters = int(opt.train.niter)
    total_epoches = int(math.ceil(total_iters / train_size))
    print('Total epochs needed: %d' % total_epoches)

    # Create val dataset
    val_set_opt = opt.datasets[1]
    val_set = create_dataset(val_set_opt)
    val_size = len(val_set)
    print('Number of val images: %d' % val_size)

    # Create dataloader
    train_loader = create_dataloader(train_set, train_set_opt)
    val_loader = create_dataloader(val_set, val_set_opt)

    # Create model
    model = create_model(opt)

    # Create binarization module
    import bin
    global bin_op
    bin_op = bin.BinOp(model.netG)
    
    model.train()

    # Create logger
    logger = Logger(opt)

    current_step = 0
    need_make_val_dir = True
    start_time = time.time()
    for epoch in range(total_epoches):
        for i, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break

            train_start_time = time.time()
            # Training
            model.feed_data(train_data)

            # optimize_parameters() is split up and replaced inline below, for binarization
            # model.optimize_parameters(current_step)
            bin_op.binarization()
            model.forward_G()
            model.optimizer_G.zero_grad()
            model.backward_G()
            bin_op.restore()
            bin_op.updateBinaryGradWeight()
            model.optimizer_G.step()

            train_duration = time.time() - train_start_time

            if current_step % opt.logger.print_freq == 0:
                losses = model.get_current_losses()
                logger.print_results(losses, epoch, current_step, train_duration, 'loss')

            if current_step % opt.logger.save_checkpoint_freq == 0:
                print('Saving the model at the end of current_step %d' % (current_step))
                model.save(current_step)

            # Validation
            if current_step % opt.train.val_freq == 0:
                validate(val_loader, val_size, model, logger, epoch, current_step)

            model.update_learning_rate(step=current_step, scheme=opt.train.lr_scheme)

        print('End of Epoch %d' % epoch)

    print('Saving the final model')
    model.save('latest')

    print('End of Training \t Time Taken: %d sec' % (time.time() - start_time))
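
The training step above follows a binarize/restore cycle: weights are sign-binarized for the forward and backward passes, restored to full precision, and only then updated. A minimal sketch of that pattern with a hypothetical BinOp-like helper (bin.BinOp itself is not shown):

import torch

class TinyBinOp:
    """Hypothetical sketch of a BinOp-style helper: sign-binarize weights
    for forward/backward, then restore the saved full-precision copies."""
    def __init__(self, module):
        self.params = [p for p in module.parameters() if p.dim() > 1]
        self.saved = [None] * len(self.params)

    def binarization(self):
        for i, p in enumerate(self.params):
            self.saved[i] = p.data.clone()
            p.data = p.data.sign() * p.data.abs().mean()

    def restore(self):
        for i, p in enumerate(self.params):
            p.data = self.saved[i]

net = torch.nn.Linear(4, 2)
bin_op = TinyBinOp(net)
bin_op.binarization()
loss = net(torch.randn(3, 4)).pow(2).mean()
loss.backward()              # gradients are computed w.r.t. the binarized weights
bin_op.restore()             # full-precision weights come back before the update
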
Example #18
# from util.util import tensor2im, get_display_image, one_hot
#
# from val import get_test_result
# #from eval.test_is_2 import get_IS
# import fnmatch

opt = TestOptions().parse(save=False)
opt.nThreads = 1  # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

model = create_model(opt, 'resNet')
visualizer = Visualizer(opt)
dataset_size = len(data_loader)
print('#testing images = %d' % dataset_size)


def main():
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
Example #19
opt = TestOptions().parse(save=False)
opt.nThreads = 1  # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

opt.stage = 12  ## choose stage_I_II_dataset.py
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

opt.name = "stage_I_gan_ganFeat_noL1_oneD_Parsing_bz50_parsing20_04222"
opt.which_G = "resNet"
opt.stage = 1
opt.which_epoch = 100
model_1 = create_model(opt)

opt.name = "gan_L1_feat_vgg_notv_noparsing_afftps_05102228"
opt.which_G = "wapResNet_v3_afftps"
opt.stage = 2
opt.which_epoch = 45
model_2 = create_model(opt)

visualizer = Visualizer(opt)
dataset_size = len(data_loader)
print('#testing images = %d' % dataset_size)

geo = GeoAPI()
affTnf = GeometricTnf(geometric_model='affine', use_cuda=False)
tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=False)
Example #20
def train():
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training frames = %d' % dataset_size)

    ### initialize models
    models = create_model(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = create_optimizer(
        opt, models)

    ### set parameters
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc_1, input_nc_2, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params_composer(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = data_loader.dataset.init_data_params(
                data, n_gpus, tG)
            fake_SI_prev_last, frames_all = data_loader.dataset.init_data(
                t_scales)

            for i in range(0, n_frames_total, n_frames_load):
                input_TParsing, input_TFG, input_SPose, input_SParsing, input_SFG, input_BG, input_SFG_full, input_SI = data_loader.dataset.prepare_data_composer(
                    data, i)

                ###################################### Forward Pass ##########################
                ####### generator
                fake_SI, fake_SI_raw, fake_sd, fake_SFG_full, fake_SFG_res, flow, weight, real_input_T, real_input_S, real_input_SFG, real_input_BG, real_SIp, real_SFG_fullp, fake_SI_last \
                    = modelG(input_TParsing, input_TFG, input_SPose, input_SParsing, input_SFG, input_BG, input_SFG_full, input_SI, fake_SI_prev_last)

                ####### discriminator
                ### individual frame discriminator
                # the collections of previous and current real frames
                real_SI_prev, real_SI = real_SIp[:, :-1], real_SIp[:, 1:]
                real_SFG_full_prev, real_SFG_full = real_SFG_fullp[:, :-1], real_SFG_fullp[:, 1:]
                #flow_ref, conf_ref = flowNet(real_SI, real_SI_prev)       # reference flows and confidences
                flow_ref, conf_ref = flowNet(
                    real_SFG_full,
                    real_SFG_full_prev)  # reference flows and confidences
                fake_SI_prev = modelG.module.compute_fake_B_prev(
                    real_SI_prev, fake_SI_prev_last, fake_SI)
                fake_SI_prev_last = fake_SI_last

                real_input_BG_flag = data['BG_flag']
                losses = modelD(0, [
                    real_SI, real_SFG_full, fake_SI, fake_SI_raw,
                    fake_SFG_full, fake_SFG_res, real_input_T, real_input_S,
                    real_input_SFG, real_input_BG, real_input_BG_flag,
                    real_SI_prev, fake_SI_prev, real_SFG_full_prev, flow,
                    weight, flow_ref, conf_ref
                ])
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = modelD.module.get_all_skipped_frames(frames_all, \
                        real_SI, fake_SI, flow_ref, conf_ref, t_scales, tD, n_frames_load, i, flowNet)

                # run discriminator for each temporal scale
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_G, loss_D, loss_D_T, t_scales_act = modelD.module.get_losses(
                    loss_dict, loss_dict_T, t_scales)

                ###################################### Backward Pass #################################
                # update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])

                if i == 0:
                    fake_SI_first = fake_SI[0, 0]  # the first generated image in this sequence

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors_new(epoch, epoch_iter, errors,
                                                    modelD.module.loss_names,
                                                    t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = util.save_all_tensors_composer(
                    opt, real_input_T, real_input_S, real_input_SFG,
                    real_input_BG, fake_SI, fake_SI_raw, fake_SI_first,
                    fake_SFG_full, fake_SFG_res, fake_sd, real_SI,
                    real_SFG_full, flow_ref, conf_ref, flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch and update model params
        save_models(opt,
                    epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    modelG,
                    modelD,
                    end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)
Example #21
def main():
    os.makedirs('sample', exist_ok=True)
    opt = TestOptions().parse()

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# Inference images = %d' % dataset_size)

    model = create_model(opt)

    for i, data in enumerate(dataset):

        # add gaussian noise channel
        # wash the label
        # np.float / np.int aliases were removed in modern NumPy; use the builtins
        t_mask = torch.FloatTensor(
            (data['label'].cpu().numpy() == 7).astype(float))
        #
        # data['label'] = data['label'] * (1 - t_mask) + t_mask * 4
        mask_clothes = torch.FloatTensor(
            (data['label'].cpu().numpy() == 4).astype(int))
        mask_fore = torch.FloatTensor(
            (data['label'].cpu().numpy() > 0).astype(int))
        mask_back = torch.FloatTensor(
            (data['label'].cpu().numpy() == 0).astype(int))

        img_fore = data['image'] * mask_fore
        img_back = data['image'] * mask_back

        all_clothes_label = changearm(data['label'])

        ############## Forward Pass ######################
        fake_image, warped_cloth, refined_cloth = model(
            Variable(data['label'].cuda()), Variable(data['edge'].cuda()),
            Variable(img_fore.cuda()), Variable(mask_clothes.cuda()),
            Variable(data['color'].cuda()), Variable(all_clothes_label.cuda()),
            Variable(data['image'].cuda()), Variable(data['pose'].cuda()),
            Variable(data['image'].cuda()), Variable(mask_fore.cuda()))

        # Restore Background
        fake_image = fake_image * mask_fore.cuda() + img_back.cuda()

        # make output folders
        output_dir = os.path.join(opt.results_dir, opt.phase)
        fake_image_dir = os.path.join(output_dir, 'try-on')
        os.makedirs(fake_image_dir, exist_ok=True)
        warped_cloth_dir = os.path.join(output_dir, 'warped_cloth')
        os.makedirs(warped_cloth_dir, exist_ok=True)
        refined_cloth_dir = os.path.join(output_dir, 'refined_cloth')
        os.makedirs(refined_cloth_dir, exist_ok=True)

        # save output
        for j in range(opt.batchSize):
            print("Saving", data['name'][j])
            util.save_tensor_as_image(
                fake_image[j], os.path.join(fake_image_dir, data['name'][j]))
            util.save_tensor_as_image(
                warped_cloth[j], os.path.join(warped_cloth_dir,
                                              data['name'][j]))
            util.save_tensor_as_image(
                refined_cloth[j],
                os.path.join(refined_cloth_dir, data['name'][j]))
Example #22
import os
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from util import html

opt = TestOptions().parse()
opt.nThreads = 1   # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
# test
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals()
    img_path = model.get_image_paths()
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)
Example #23
import os  # needed below for os.path.join
from torch.autograd import Variable
from collections import OrderedDict
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer

opt = TestOptions().parse(save=False)
opt.nThreads = 1
opt.batchSize = 1
opt.serial_batches = True

visualizer = Visualizer(opt)

modelG = create_model(opt)

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)

print('Generating %d frames' % dataset_size)

save_dir = os.path.join(opt.results_dir, opt.name, opt.which_epoch + '_epoch',
                        opt.phase)

total_distance, total_pixels = 0, 0
mtotal_distance, mtotal_pixels = 0, 0
mouth_total_distance, mouth_total_pixels = 0, 0

for i, data in enumerate(dataset):