Example #1
 def get_current_visuals(self):
     real_A_img, real_A_prior = util.tensor2im(self.real_A.data)
     fake_B = util.tensor2im(self.fake_B.data)
     real_B = util.tensor2im(self.real_B.data)
     if self.opt.output_nc == 1:
         fake_B_postprocessed = util.postprocess_parsing(fake_B, self.isTrain)
         fake_B_color = util.paint_color(fake_B_postprocessed)
         real_B_color = util.paint_color(util.postprocess_parsing(real_B, self.isTrain))
     if self.opt.output_nc == 1:
         return OrderedDict([
             ('real_A_img', real_A_img),
             ('real_A_prior', real_A_prior),
             ('fake_B', fake_B),
             ('fake_B_postprocessed', fake_B_postprocessed),
             ('fake_B_color', fake_B_color),
             ('real_B', real_B),
             ('real_B_color', real_B_color)]
         )
     else:
         return OrderedDict([
             ('real_A_img', real_A_img),
             ('real_A_prior', real_A_prior),
             ('fake_B', fake_B),
             ('real_B', real_B)]
         )
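Every snippet in this collection calls util.tensor2im to turn an output tensor into a displayable image array. For orientation, here is a minimal sketch of what such a helper typically looks like in the pix2pix / CycleGAN / pix2pixHD family these examples come from; the exact behaviour (input range, grayscale handling, batch handling) differs between projects, so the name tensor2im_sketch and the details below are an approximation rather than the definitive implementation.

import numpy as np
import torch

def tensor2im_sketch(input_image, imtype=np.uint8):
    # Move the tensor to CPU and convert to a float numpy array.
    if isinstance(input_image, torch.Tensor):
        image_numpy = input_image.detach().cpu().float().numpy()
    else:
        return input_image
    # If a batch dimension is present, keep only the first image.
    if image_numpy.ndim == 4:
        image_numpy = image_numpy[0]
    # Replicate single-channel outputs to 3 channels so they display as RGB.
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    # Map from the network's [-1, 1] range to [0, 255] and reorder to (H, W, C).
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)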
Example #2
 def style_forward(self, click_pt, style_id=-1):           
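     # Re-synthesize the output after a style change. If click_pt is None the whole
     # image is simply regenerated; otherwise the instance under click_pt is looked up
     # and, for style_id == -1, one cropped preview per feature cluster is produced,
     # while a specific style_id applies that cluster's features to the full image.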
     if click_pt is None:            
         self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
         self.crop = None
         self.mask = None        
     else:                       
         instToChange = int(self.object_map[0, 0, click_pt[0], click_pt[1]])
         self.instToChange = instToChange
         label = instToChange if instToChange < 1000 else instToChange//1000        
         self.feat = self.features_clustered[label]
         self.fake_image = []
         self.mask = self.object_map == instToChange
         idx = self.mask.nonzero()
         self.get_crop_region(idx)            
         if idx.size():                
             if style_id == -1:
                 (min_y, min_x, max_y, max_x) = self.crop
                 ### original
                 for cluster_idx in range(self.opt.multiple_output):
                     self.set_features(idx, self.feat, cluster_idx)
                     fake_image = self.single_forward(self.net_input, self.feat_map)
                     fake_image = util.tensor2im(fake_image[:,min_y:max_y,min_x:max_x])
                     self.fake_image.append(fake_image)    
                 """### To speed up previewing different style results, either crop or downsample the label maps
                 if instToChange > 1000:
                     (min_y, min_x, max_y, max_x) = self.crop                                                
                     ### crop                                                
                     _, _, h, w = self.net_input.size()
                     offset = 512
                     y_start, x_start = max(0, min_y-offset), max(0, min_x-offset)
                     y_end, x_end = min(h, (max_y + offset)), min(w, (max_x + offset))
                     y_region = slice(y_start, y_start+(y_end-y_start)//16*16)
                     x_region = slice(x_start, x_start+(x_end-x_start)//16*16)
                     net_input = self.net_input[:,:,y_region,x_region]                    
                     for cluster_idx in range(self.opt.multiple_output):  
                         self.set_features(idx, self.feat, cluster_idx)
                         fake_image = self.single_forward(net_input, self.feat_map[:,:,y_region,x_region])                            
                         fake_image = util.tensor2im(fake_image[:,min_y-y_start:max_y-y_start,min_x-x_start:max_x-x_start])
                         self.fake_image.append(fake_image)
                 else:
                     ### downsample
                     (min_y, min_x, max_y, max_x) = [crop//2 for crop in self.crop]                    
                     net_input = self.net_input[:,:,::2,::2]                    
                     size = net_input.size()
                     net_input_batch = net_input.expand(self.opt.multiple_output, size[1], size[2], size[3])             
                     for cluster_idx in range(self.opt.multiple_output):  
                         self.set_features(idx, self.feat, cluster_idx)
                         feat_map = self.feat_map[:,:,::2,::2]
                         if cluster_idx == 0:
                             feat_map_batch = feat_map
                         else:
                             feat_map_batch = torch.cat((feat_map_batch, feat_map), dim=0)
                     fake_image_batch = self.single_forward(net_input_batch, feat_map_batch)
                     for i in range(self.opt.multiple_output):
                         self.fake_image.append(util.tensor2im(fake_image_batch[i,:,min_y:max_y,min_x:max_x]))"""
                                     
             else:
                 self.set_features(idx, self.feat, style_id)
                 self.cluster_indices[label] = style_id
                 self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))        
    def get_current_visuals(self):
        fake_B_audio = self.audio_gen_fakes.view(-1, self.opt.sequence_length, self.opt.image_channel_size, self.opt.image_size, self.opt.image_size)
        real_A = util.tensor2im(self.real_A.data)
        ordered_dict = OrderedDict([('real_A', real_A)])
        fake_audio_B = {}
        fake_image_B = {}
        for i in range(self.opt.sequence_length):
            fake_audio_B[i] = util.tensor2im(fake_B_audio[:, i, :, :, :].data)
            ordered_dict['fake_audio_B_' + str(i)] = fake_audio_B[i]

        return ordered_dict
    def forward(self, in0, in1):
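        # Compute DSSIM (structural dissimilarity) between the two single-image batches,
        # either directly in RGB or after converting both tensors to Lab, and return the
        # scalar wrapped in a (optionally CUDA) Variable so it can flow through the pipeline.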
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
        elif(self.colorspace=='Lab'):
            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
        ret_var = Variable( torch.Tensor((value,) ) )
        if(self.use_gpu):
            ret_var = ret_var.cuda()
        return ret_var
    def get_current_visuals(self):
        zoom_factor = 256/self.var_ref.data.size()[2]
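        # Nearest-neighbour upsample (presumably scipy.ndimage.zoom with order=0) so the
        # patches are displayed at roughly 256 pixels regardless of their native size.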

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
        p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
        p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)

        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])
Example #6
    def change_labels(self, click_src, click_tgt): 
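        # Repaint the whole instance under click_src with the label/instance id found
        # at click_tgt, keeping the one-hot network input, the edge channel and the
        # per-instance features consistent before re-rendering the image.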
        y_src, x_src = click_src[0], click_src[1]
        y_tgt, x_tgt = click_tgt[0], click_tgt[1]
        label_src = int(self.label_map[0, 0, y_src, x_src])
        inst_src = self.inst_map[0, 0, y_src, x_src]
        label_tgt = int(self.label_map[0, 0, y_tgt, x_tgt])
        inst_tgt = self.inst_map[0, 0, y_tgt, x_tgt]

        idx_src = (self.inst_map == inst_src).nonzero()         
        # need to change 3 things: label map, instance map, and feature map
        if idx_src.shape:
            # backup current maps
            self.backup_current_state() 

            # change both the label map and the network input
            self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
            self.net_input[idx_src[:,0], idx_src[:,1] + label_src, idx_src[:,2], idx_src[:,3]] = 0
            self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1                                    
            
            # update the instance map (and the network input)
            if inst_tgt > 1000:
                # if different instances have different ids, give the new object a new id
                tgt_indices = (self.inst_map > label_tgt * 1000) & (self.inst_map < (label_tgt+1) * 1000)
                inst_tgt = self.inst_map[tgt_indices].max() + 1
            self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = inst_tgt
            self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)

            # also copy the source features to the target position      
            idx_tgt = (self.inst_map == inst_tgt).nonzero()    
            if idx_tgt.shape:
                self.copy_features(idx_src, idx_tgt[0,:])

        self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
    def get_current_visuals(self, testing=False):
        if not testing:
            self.visuals = [self.real_A, self.fake_B, self.rec_A, self.real_B, self.fake_A, self.rec_B]
            self.labels = ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B']
            if self.opt.lambda_identity > 0.0:
                self.visuals += [self.idt_A, self.idt_B]
                self.labels += ['idt_A', 'idt_B']

        images = [util.tensor2im(v.data) for v in self.visuals]
        return OrderedDict(zip(self.labels, images))
 def get_current_visuals(self):
     real_A = util.tensor2im(self.input_A)
     fake_B = util.tensor2im(self.fake_B)
     rec_A = util.tensor2im(self.rec_A)
     real_B = util.tensor2im(self.input_B)
     fake_A = util.tensor2im(self.fake_A)
     rec_B = util.tensor2im(self.rec_B)
     ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
     if self.opt.isTrain and self.opt.lambda_identity > 0.0:
         ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
         ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
     return ret_visuals
def sketch(img):
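    # Convert a PIL image to a normalized CHW float tensor, run it through the globally
    # loaded pix2pix-style model (as both the A and B inputs), and return the generated
    # fake_B output converted back to a PIL image.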
    img = pil2cv(img) # cv2.imread(image_path)
    img = img / 255.
    h, w = img.shape[0:2]
    img = np.transpose(img, (2, 0, 1))
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float() #.to(device)
    data = {'A_paths': '', 'A': img, 'B': img }
    model.set_input(data)
    model.test()
    output = util.tensor2im(model.fake_B)
    return cv2pil(output)
Example #10
    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A  = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B  = util.tensor2im(self.rec_B.data)

        AE_fake_A = util.tensor2im(self.AEfakeA.view(1,1,28,28).data)
        AE_fake_B = util.tensor2im(self.AEfakeB.view(1,1,28,28).data)


        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A),
                                ('AE_fake_A', AE_fake_A), ('AE_fake_B', AE_fake_B)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B),
                                ('AE_fake_A', AE_fake_A), ('AE_fake_B', AE_fake_B)])
Example #11
    def add_objects(self, click_src, label_tgt, mask, style_id=0):
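        # Paste a new object of class label_tgt at the clicked position: the binary
        # mask is offset to (y, x), then the label map, one-hot network input, instance
        # map, edge channel and feature map are all updated before re-rendering.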
        y, x = click_src[0], click_src[1]
        mask = np.transpose(mask, (2, 0, 1))[np.newaxis,...]        
        idx_src = torch.from_numpy(mask).cuda().nonzero()        
        idx_src[:,2] += y
        idx_src[:,3] += x

        # backup current maps
        self.backup_current_state()

        # update label map
        self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt        
        for k in range(self.opt.label_nc):
            self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
        self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1            

        # update instance map
        self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
        self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
                
        # update feature map
        self.set_features(idx_src, self.feat, style_id)                
        
        self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
Example #12
    def add_strokes(self, click_src, label_tgt, bw, save):
        # get the region of the new strokes (bw is the brush width)        
        size = self.net_input.size()
        h, w = size[2], size[3]
        idx_src = torch.LongTensor(bw**2, 4).fill_(0)
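        # Fill a (bw*bw, 4) list of (n, c, y, x) indices covering a bw x bw square
        # centred on the click, clamping the coordinates to the image borders.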
        for i in range(bw):
            idx_src[i*bw:(i+1)*bw, 2] = min(h-1, max(0, click_src[0]-bw//2 + i))
            for j in range(bw):
                idx_src[i*bw+j, 3] = min(w-1, max(0, click_src[1]-bw//2 + j))
        idx_src = idx_src.cuda()
        
        # again, need to update 3 things
        if idx_src.shape:
            # backup current maps
            if save:
                self.backup_current_state()

            # update the label map (and the network input) in the stroke region            
            self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
            for k in range(self.opt.label_nc):
                self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
            self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1                 

            # update the instance map (and the network input)
            self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
            self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
            
            # also update the features if available
            if self.opt.instance_feat:                                            
                feat = self.features_clustered[label_tgt]
                #np.random.seed(label_tgt+1)   
                #cluster_idx = np.random.randint(0, feat.shape[0])
                cluster_idx = self.cluster_indices[label_tgt]
                self.set_features(idx_src, feat, cluster_idx)                                                  
        
        self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
from models.models import create_model
import os
import util.util as util
from torch.autograd import Variable
import torch.nn as nn

opt = TrainOptions().parse()
opt.nThreads = 1
opt.batchSize = 1 
opt.serial_batches = True 
opt.no_flip = True
opt.instance_feat = True

name = 'features'
save_path = os.path.join(opt.checkpoints_dir, opt.name)

############ Initialize #########
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = create_model(opt)
util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))

######## Save precomputed feature maps for 1024p training #######
for i, data in enumerate(dataset):
	print('%d / %d images' % (i+1, dataset_size)) 
	feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
	feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
	image_numpy = util.tensor2im(feat_map.data[0])
	save_path = data['path'][0].replace('/train_label/', '/train_feat/')
	util.save_image(image_numpy, save_path)
            if total_steps % opt.print_freq == 0:
                errors = {}
                if torch.__version__[0] == '1':
                    errors = {k: v.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                else:
                    errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                syn = generated[0].data[0]
                inputs = torch.cat((data['label'], data['next_label']), dim=3)
                targets = torch.cat((data['image'], data['next_image']), dim=3)
                visuals = OrderedDict([('input_label', util.tensor2im(inputs[0], normalize=False)),
                                           ('synthesized_image', util.tensor2im(syn)),
                                           ('real_image', util.tensor2im(targets[0]))])
                if opt.face_generator: #display face generator on tensorboard
                    miny, maxy, minx, maxx = data['face_coords'][0]
                    res_face = generated[2].data[0]
                    syn_face = generated[1].data[0]
                    preres = generated[3].data[0]
                    visuals= OrderedDict([('input_label', util.tensor2im(inputs[0], normalize=False)),
                                           ('synthesized_image', util.tensor2im(syn)),
                                           ('synthesized_face', util.tensor2im(syn_face)),
                                           ('residual', util.tensor2im(res_face)),
                                           ('real_face', util.tensor2im(data['image'][0][:, miny:maxy, minx:maxx])),
                                           # ('pre_residual', util.tensor2im(preres)),
                                           # ('pre_residual_face', util.tensor2im(preres[:, miny:maxy, minx:maxx])),
                                           ('input_face', util.tensor2im(data['label'][0][:, miny:maxy, minx:maxx], normalize=False)),
Example #15
    Path(fake_dir).mkdir(parents=True, exist_ok=True)
    real_dir = os.path.join(result_dir, "real")
    Path(real_dir).mkdir(parents=True, exist_ok=True)
    for i, data in enumerate(dataset):
        # if i >= opt.num_test:  # only apply our model to opt.num_test images.
        #     break

        model.set_input(data)  # unpack data from data loader
        model.test()  # run inference
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()  # get image paths

        if i % 5 == 0:  # save images to an HTML file
            print('processing (%04d)-th image... %s' % (i, img_path))

        # save_images(webpage, visuals, img_path,
        #             aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)

        real_im = util.tensor2im(visuals['real'])
        fake_im = util.tensor2im(visuals['fake'])
        util.save_image(real_im,
                        os.path.join(real_dir,
                                     str(i) + ".png"),
                        aspect_ratio=opt.aspect_ratio)
        util.save_image(fake_im,
                        os.path.join(fake_dir,
                                     str(i) + ".png"),
                        aspect_ratio=opt.aspect_ratio)

    # webpage.save()  # save the HTML
Example #16
        ############## Display results and errors ##########
        ### print out errors
        # if total_steps % opt.print_freq == print_delta:
        #     errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
        #     t = (time.time() - iter_start_time) / opt.print_freq
        #     visualizer.print_current_errors(epoch, epoch_iter, errors, t)
        #     visualizer.plot_current_errors(errors, total_steps)
        #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

        ### display output images
        if save_fake:
            visuals = OrderedDict([
                ('input_label',
                 util.tensor2label(data['label'][0], opt.label_nc)),
                ('synthesized_image', util.tensor2im(generated.data[0])),
                ('real_image', util.tensor2im(data['image'][0]))
            ])
            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.module.save('latest')
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

        if epoch_iter >= dataset_size:
            break

    # end of epoch
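Example #16 above (and several later snippets) also call util.tensor2label to colorize a semantic label map before display. The real helper in pix2pixHD-style code uses a fixed color palette and falls back to tensor2im when label_nc is 0; the following is only a rough, self-contained approximation with a pseudo-random palette, and the name tensor2label_sketch is made up for illustration.

import numpy as np
import torch

def tensor2label_sketch(label_tensor, n_label):
    # Collapse a one-hot (n_label, H, W) tensor to a single channel of class ids.
    label = label_tensor.detach().cpu().float()
    if label.dim() == 3 and label.size(0) > 1:
        label = label.max(0, keepdim=True)[1].float()
    ids = label.squeeze().long().numpy()
    # Deterministic pseudo-random palette: one RGB color per class id.
    rng = np.random.RandomState(0)
    palette = rng.randint(0, 256, size=(max(n_label, int(ids.max()) + 1), 3), dtype=np.uint8)
    return palette[ids]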
Example #17
 def get_current_visuals(self):
     fake_B = util.tensor2im(self.fake_B)
     return fake_B
Example #18
    def get_current_visuals(self):
        cond_A = util.tensor2im(self.mask_A)
        cond_AA = util.tensor2im(self.mask_AA)
        cond_B = util.tensor2im(self.mask_B)
        cond_BB = util.tensor2im(self.mask_BB)
        fake_A = util.tensor2im(self.fake_A)
        fake_B = util.tensor2im(self.fake_B)
        fake_AB = util.tensor2im(self.fake_AB)
        fake_BA = util.tensor2im(self.fake_BA)
        fake_AC = util.tensor2im(self.fake_AC)
        fake_BC = util.tensor2im(self.fake_BC)

        ret_visuals = OrderedDict([('fake_A', fake_A), ('fake_B', fake_B),
                                   ('cond_A', cond_A), ('cond_AA', cond_AA),
                                   ('cond_B', cond_B), ('cond_BB', cond_BB),
                                   ('fake_AB', fake_AB), ('fake_BA', fake_BA),
                                   ('fake_AC', fake_AC), ('fake_BC', fake_BC)])
        if not self.opt.isG3:
            ret_visuals['real_A'] = util.tensor2im(self.input_A)
            ret_visuals['real_B'] = util.tensor2im(self.input_B)

        if self.opt.isTrain and self.opt.identity > 0.0:
            ret_visuals['idt_AC'] = util.tensor2im(self.idt_AC)
            ret_visuals['idt_BC'] = util.tensor2im(self.idt_BC)
        return ret_visuals
Example #19
input_nc = 1 if opt.label_nc != 0 else opt.input_nc
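# When a semantic label map is used (label_nc != 0) the A input is a single channel
# of integer class ids; otherwise the raw opt.input_nc channel count applies.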

save_dir = os.path.join(opt.results_dir, opt.name,
                        '%s_%s' % (opt.phase, opt.which_epoch))
print('Doing %d frames' % len(dataset))
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    if data['change_seq']:
        model.fake_B_prev = None

    _, _, height, width = data['A'].size()
    A = Variable(data['A']).view(1, -1, input_nc, height, width)
    B = Variable(data['B']).view(1, -1, opt.output_nc, height,
                                 width) if len(data['B'].size()) > 2 else None
    inst = Variable(data['inst']).view(
        1, -1, 1, height, width) if len(data['inst'].size()) > 2 else None
    generated = model.inference(A, B, inst)

    if opt.label_nc != 0:
        real_A = util.tensor2label(generated[1], opt.label_nc)
    else:
        c = 3 if opt.input_nc == 3 else 1
        real_A = util.tensor2im(generated[1][:c], normalize=False)

    visual_list = [('real_A', real_A),
                   ('fake_B', util.tensor2im(generated[0].data[0]))]
    visuals = OrderedDict(visual_list)
    img_path = data['A_path']
    print('process image... %s' % img_path)
    visualizer.save_images(save_dir, visuals, img_path)
Example #20
def main():
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break

        a_image_tensor = data['a_image_tensor']  # 3
        b_image_tensor = data['b_image_tensor']  # 3
        b_label_tensor = data['b_label_tensor']  # 18
        a_parsing_tensor = data['a_parsing_tensor']  # 1
        b_parsing_tensor = data['b_parsing_tensor']  # 1
        b_label_show_tensor = data['b_label_show_tensor']
        theta_aff = data['theta_aff_tensor']  # 2
        theta_tps = data['theta_tps_tensor']  # 2
        theta_aff_tps = data['theta_aff_tps_tensor']  # 2
        policy_binary = data['policy_binary']  # 1
        a_jpg_path = data['a_jpg_path']
        b_jpg_path = data['b_jpg_path']

        input_tensor = torch.cat([a_image_tensor, b_image_tensor, b_label_tensor, a_parsing_tensor, b_parsing_tensor, \
                                  theta_aff, theta_tps, theta_aff_tps, policy_binary], dim=1)
        input_var = Variable(input_tensor.type(torch.cuda.FloatTensor))
        model.eval()
        fake_b = model.inference(input_var)

        # test_list = [('b_label_show', util.tensor2im(b_label_show_tensor[0])),
        #               ('a_image', util.tensor2im(a_image_tensor[0])),
        #              ('fake_b_parsing', util.tensor2im(
        #                  util.parsingim_2_tensor(b_parsing_tensor[0], opt=opt, parsing_label_nc=opt.parsing_label_nc)[0])),
        #               ('fake_image', util.tensor2im(fake_b.data[0])),
        #               ('b_image', util.tensor2im(b_image_tensor[0]))]

        a_parsing_rgb_tensor = parsingim_2_tensor(
            a_parsing_tensor[0],
            opt=opt,
            parsing_label_nc=opt.parsing_label_nc)
        b_parsing_rgb_tensor = parsingim_2_tensor(
            b_parsing_tensor[0],
            opt=opt,
            parsing_label_nc=opt.parsing_label_nc)

        show_image_tensor_1 = torch.cat(
            (a_image_tensor, b_label_show_tensor, b_image_tensor), dim=3)
        show_image_tensor_2 = torch.cat(
            (a_parsing_rgb_tensor, b_parsing_rgb_tensor,
             fake_b.data[0:1, :, :, :].cpu()),
            dim=3)
        show_image_tensor = torch.cat(
            (show_image_tensor_1[0:1, :, :, :], show_image_tensor_2), dim=2)
        test_list = [('a-b-fake_b', tensor2im(show_image_tensor[0])),
                     ('fake_image', util.tensor2im(fake_b.data[0])),
                     ('b_image', util.tensor2im(b_image_tensor[0]))]

        ### save image
        visuals = OrderedDict(test_list)
        visualizer.save_images(webpage, visuals, a_jpg_path[0], b_jpg_path[0])

        print('[%s]process image %s' % (i, a_jpg_path[0]))
        ### Starting from scratch, why are there only 12779 images? There should be 12800! Could 11 pairs be duplicates? Check the pair file...
        ### Strange! Maybe it needs to be 12800 + 21

    webpage.save()

    image_dir = webpage.get_image_dir()
    print(image_dir)
Example #21
        eye_gaze_video = Variable(data['eye_video']).view(
            1, -1, 3, height, width)
        input_A = torch.cat([nmfc_video, eye_gaze_video], dim=2)
    img_path = data['A_paths']

    print('Processing NMFC image %s' % img_path[-1])

    generated = modelG.inference(input_A, rgb_video)

    if opt.time_fwd_pass:
        end.record()
        # Waits for everything to finish running
        torch.cuda.synchronize()
        print('Forward pass time: %.6f' % start.elapsed_time(end))

    fake_frame = util.tensor2im(generated[0].data[0])
    rgb_frame = util.tensor2im(rgb_video[0, -1])
    nmfc_frame = util.tensor2im(nmfc_video[0, -1], normalize=False)
    if not opt.no_eye_gaze:
        eye_gaze_frame = util.tensor2im(eye_gaze_video[0, -1], normalize=False)

    visual_list = [('real', rgb_frame), ('fake', fake_frame),
                   ('nmfc', nmfc_frame)]
    if not opt.no_eye_gaze:
        visual_list += [('eye_gaze', eye_gaze_frame)]

    # If in self reenactment mode, compute pixel error and heatmap.
    if not opt.do_reenactment:
        total_distance, total_pixels, heatmap = util.get_pixel_distance(
            rgb_frame, fake_frame, total_distance, total_pixels)
        mtotal_distance, mtotal_pixels, mheatmap = util.get_pixel_distance(
Example #22
dataset = CreateDataset(opt)

# test
model = create_model(opt)
if opt.verbose:
    print(model)

# test whole video sequence
# 20181009: do we use first frame as input?

data = dataset[0]
if opt.use_first_frame:
    prev_frame = data['image']
    start_from = 1
    from skimage.io import imsave
    imsave('results/ref.png', util.tensor2im(prev_frame))
    generated = [util.tensor2im(prev_frame)]
else:
    prev_frame = torch.zeros_like(data['image'])
    start_from = 0
    generated = []

from skimage.io import imsave
for i in tqdm(range(start_from, dataset.clip_length)):
    label = data['label'][i:i + 1]
    #print(label.shape)
    inst = None if opt.no_instance else data['inst'][i:i + 1]

    cur_frame = model.inference(label, inst, torch.unsqueeze(prev_frame,
                                                             dim=0))
    prev_frame = cur_frame.data[0]
Example #23
    # test with eval mode. This only affects layers like batchnorm and dropout.
    # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
    # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
    if opt.eval:
        model.eval()
    for i, data in enumerate(dataset):
        #if i >= opt.num_test:  # only apply our model to opt.num_test images.
        #    break
        model.set_input(data)  # unpack data from data loader
        model.test()  # run inference
        fake_B = getattr(model, 'fake_B')  # (1,1,64,128,128)
        # fake_B = torch.squeeze(fake_B)
        # fake_B_img = fake_B.cpu()
        # img = torchvision.transforms.ToPILImage()(fake_B_img[0].unsqueeze(0))
        patient_id = data['B_paths'][0].split('/')[-1]
        save_dir = os.path.join(opt.results_dir, patient_id)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        for idx in range(fake_B.shape[2]):
            save_path = os.path.join(save_dir, str(idx) + '.png')
            im = util.tensor2im(fake_B[:, :, idx, :, :])
            util.save_image(im, save_path)

        #img.save('test.png')
        '''visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()     # get image paths
        if i % 5 == 0:  # save images to an HTML file
            print('processing (%04d)-th image... %s' % (i, img_path))
        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) comment by cz'''
    # webpage.save()  # save the HTML
Example #24
def save_img(tensor, save_name, save_path):
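    # Convert a tensor to an image array with tensor2im and write it to
    # save_path/save_name, creating the directory first if it does not exist.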
    im = tensor2im(tensor)
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    save_image(im, os.path.join(save_path,save_name))
Example #25
 ### print out errors
 if total_steps % print_freq == print_delta:
     errors = {
         k: v.data.item() if not isinstance(v, int) else v
         for k, v in loss_dict.items()
     }
     t = (time.time() - iter_start_time) / print_freq
     visualizer.print_current_errors(epoch, epoch_iter, errors, t)
     visualizer.plot_current_errors(errors, total_steps)
     # call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])
     ### display output images
     if save_fake:
         visuals = OrderedDict([('synthesized_image',
                                 cv2.cvtColor(generated,
                                              cv2.COLOR_BGR2RGB)),
                                ('real_image', util.tensor2im(image))])
         cv2.imwrite(
             'trainimage/synthesized_image_e' + str(epoch) + '_i' +
             str(i) + '.jpg', visuals["synthesized_image"])
         cv2.imwrite(
             'trainimage/real_image_e' + str(epoch) + '_i' + str(i) +
             '.jpg', visuals["real_image"])
         visualizer.display_current_results(visuals, epoch, total_steps)
     if total_steps % save_latest_freq == save_delta:
         print('saving the latest model (epoch %d, total_steps %d)' %
               (epoch, total_steps))
         model.save(epoch)
         np.savetxt(iter_path, (epoch, epoch_iter),
                    delimiter=',',
                    fmt='%d')
     if epoch_iter >= dataset_size:
 def get_current_visuals(self):
     real_A = util.tensor2im(self.real_A.data)
     fake_B = util.tensor2im(self.fake_B.data)
     real_B = util.tensor2im(self.real_B.data)
     return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])
Example #27
 def get_current_visuals(self):
     input_A = util.tensor2im(self.input_A.data)
     # pred_B = util.tensor2im(self.pred_B.data)
     # input_B = util.tensor2im(self.input_B.data)
     return OrderedDict([('input_A', input_A)])
dataset = CreateDataset(opt)

# test
model = create_model(opt)
if opt.verbose:
    print(model)

# test whole video sequence
# 20181009: do we use first frame as input?

data = dataset[0]
if opt.use_first_frame:
    prev_frame = data['image']
    start_from = 1
    from skimage.io import imsave
    imsave('results/ref.png', util.tensor2im(prev_frame))
    generated = [util.tensor2im(prev_frame)]
else:
    prev_frame = torch.zeros_like(data['image'])
    start_from = 0
    generated = []

from skimage.io import imsave
for i in tqdm(range(start_from, dataset.clip_length)):
    label = data['label'][i:i+1]
    #print(label.shape)
    inst = None if opt.no_instance else data['inst'][i:i+1]

    cur_frame = model.inference(label, inst, torch.unsqueeze(prev_frame, dim=0))
    prev_frame = cur_frame.data[0]
    
Example #29
    elif opt.data_type == 8:
        data['label'] = data['label'].uint8()
        data['inst'] = data['inst'].uint8()
    if opt.export_onnx:
        print("Exporting to ONNX: ", opt.export_onnx)
        assert opt.export_onnx.endswith(
            "onnx"), "Export model file should end with .onnx"
        torch.onnx.export(model, [data['label'], data['inst']],
                          opt.export_onnx,
                          verbose=True)
        exit(0)
    minibatch = 1
    if opt.engine:
        generated = run_trt_engine(opt.engine, minibatch,
                                   [data['label'], data['inst']])
    elif opt.onnx:
        generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                             [data['label'], data['inst']])
    else:
        generated = model.inference(data['label'], data['inst'])

    visuals = OrderedDict([
        ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
        ('synthesized_image', util.tensor2im(generated.data[0]))
    ])
    img_path = data['path']
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()
Example #30
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model_fullts(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
# test
unset = True
print('#testing images = %d' % len(data_loader))

for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break

    if unset: #no previous results, condition on zero image
      previous_cond = torch.zeros(data['label'].size())
      unset = False

    #generated = model.inference(data['label'], previous_cond, data['face_coords'])
    generated = model.inference(data['label'], previous_cond)

    previous_cond = generated.data

    visuals = OrderedDict([('synthesized_image', util.tensor2im(generated.data[0]))])
    img_path = data['path']
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()
 def get_current_visuals(self):
     real_A = util.tensor2im(self.real_A.data)
     fake_B = util.tensor2im(self.fake_B.data)
     real_B = util.tensor2im(self.real_B.data)
     return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])
import util.util as util
from torch.autograd import Variable
import torch.nn as nn

opt = TrainOptions().parse()
opt.nThreads = 1
opt.batchSize = 1
opt.serial_batches = True
opt.no_flip = True
opt.instance_feat = True

name = 'features'
save_path = os.path.join(opt.checkpoints_dir, opt.name)

############ Initialize #########
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = create_model(opt)
util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))

######## Save precomputed feature maps for 1024p training #######
for i, data in enumerate(dataset):
    print('%d / %d images' % (i + 1, dataset_size))
    feat_map = model.module.netE.forward(
        Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
    feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
    image_numpy = util.tensor2im(feat_map.data[0])
    save_path = data['path'][0].replace('/train_label/', '/train_feat/')
    util.save_image(image_numpy, save_path)
    def get_current_visuals(self):
        real_A0 = util.tensor2im(self.input_A0)
        real_A1 = util.tensor2im(self.input_A1)
        real_A2 = util.tensor2im(self.input_A2)

        fake_B0 = util.tensor2im(self.fake_B0)
        fake_B1 = util.tensor2im(self.fake_B1)
        fake_B2 = util.tensor2im(self.fake_B2)

        rec_A = util.tensor2im(self.rec_A)

        real_B0 = util.tensor2im(self.input_B0)
        real_B1 = util.tensor2im(self.input_B1)
        real_B2 = util.tensor2im(self.input_B2)

        fake_A0 = util.tensor2im(self.fake_A0)
        fake_A1 = util.tensor2im(self.fake_A1)
        fake_A2 = util.tensor2im(self.fake_A2)

        rec_B = util.tensor2im(self.rec_B)

        pred_A2 = util.tensor2im(self.pred_A2)
        pred_B2 = util.tensor2im(self.pred_B2)

        ret_visuals = OrderedDict([('real_A0', real_A0), ('fake_B0', fake_B0),
                                   ('real_A1', real_A1), ('fake_B1', fake_B1),
                                   ('fake_B2', fake_B2), ('rec_A', rec_A),
                                   ('real_A2', real_A2), ('real_B0', real_B0),
                                   ('fake_A0', fake_A0), ('real_B1', real_B1),
                                   ('fake_A1', fake_A1), ('fake_A2', fake_A2),
                                   ('rec_B', rec_B), ('real_B2', real_B2),
                                   ('real_A2', real_A2), ('pred_A2', pred_A2),
                                   ('real_B2', real_B2), ('pred_B2', pred_B2)])
        if self.opt.isTrain and self.opt.identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals
                'warping_ref_lmark': data['ref_label'][:, 0]
            })

        img_path = data['path']
        data_list = [data['tgt_label'], data['tgt_image'], None, None, None, None, \
                    data['ref_label'], data['ref_image'], \
                    data['warping_ref_lmark'], \
                    data['warping_ref'], \
                    data['ani_lmark'].squeeze(1) if opt.warp_ani else None, \
                    data['ani_image'].squeeze(1) if opt.warp_ani else None, \
                    None, None, None]
        synthesized_image, fake_raw_img, warped_img, flow, weight, _, _, _, _, _ = model(
            data_list, ref_idx_fix=ref_idx_fix)

        visuals = [
            util.tensor2im(data['tgt_label']), \
            util.tensor2im(data['tgt_image']), \
            util.tensor2im(synthesized_image), \
            util.tensor2im(fake_raw_img), \
            util.tensor2im(warped_img[0]), \
            util.tensor2im(weight[0]), \
            util.tensor2im(warped_img[2]), \
            util.tensor2im(weight[2])
        ]
        compare_image = np.hstack([v for v in visuals if v is not None])

        synthesized_image = util.tensor2im(synthesized_image)
        tgt_image = util.tensor2im(data['tgt_image'])

        # img_id = "{}_{}_{}".format(img_path[0].split('/')[-3], img_path[0].split('/')[-2], img_path[0].split('/')[-1][:-4])
        img_id = file[:-4]
Example #35
    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        if self.opt.skip > 0:
            latent_real_A = util.tensor2im(self.latent_real_A.data)

        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)

        if self.opt.identity > 0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            if self.opt.lambda_A > 0.0:
                rec_A = util.tensor2im(self.rec_A.data)
                rec_B = util.tensor2im(self.rec_B.data)
                if self.opt.skip > 0:
                    latent_fake_A = util.tensor2im(self.latent_fake_A.data)
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('latent_real_A', latent_real_A),
                                        ('rec_A', rec_A), ('real_B', real_B),
                                        ('fake_A', fake_A), ('rec_B', rec_B),
                                        ('latent_fake_A', latent_fake_A),
                                        ("idt_A", idt_A), ("idt_B", idt_B)])
                else:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('rec_A', rec_A), ('real_B', real_B),
                                        ('fake_A', fake_A), ('rec_B', rec_B),
                                        ("idt_A", idt_A), ("idt_B", idt_B)])
            else:
                if self.opt.skip > 0:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('latent_real_A', latent_real_A),
                                        ('real_B', real_B), ('fake_A', fake_A),
                                        ("idt_A", idt_A), ("idt_B", idt_B)])
                else:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('real_B', real_B), ('fake_A', fake_A),
                                        ("idt_A", idt_A), ("idt_B", idt_B)])
        else:
            if self.opt.lambda_A > 0.0:
                rec_A = util.tensor2im(self.rec_A.data)
                rec_B = util.tensor2im(self.rec_B.data)
                if self.opt.skip > 0:
                    latent_fake_A = util.tensor2im(self.latent_fake_A.data)
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('latent_real_A', latent_real_A),
                                        ('rec_A', rec_A), ('real_B', real_B),
                                        ('fake_A', fake_A), ('rec_B', rec_B),
                                        ('latent_fake_A', latent_fake_A)])
                else:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('rec_A', rec_A), ('real_B', real_B),
                                        ('fake_A', fake_A), ('rec_B', rec_B)])
            else:
                if self.opt.skip > 0:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('latent_real_A', latent_real_A),
                                        ('real_B', real_B),
                                        ('fake_A', fake_A)])
                else:
                    return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                        ('real_B', real_B),
                                        ('fake_A', fake_A)])
def train(model, criterion, train_set, val_set, opt, labels=None):
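    # Fine-tuning loop for a multi-label classifier: the last layer gets a scaled
    # learning rate, SGD with a step scheduler drives the updates, and losses,
    # accuracies, sample images and snapshots are periodically pushed to visdom/disk.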
    # define web visualizer using visdom
    webvis = WebVisualizer(opt)
    
    # modify learning rate of last layer
    finetune_params = modify_last_layer_lr(model.named_parameters(), 
                                            opt.lr, opt.lr_mult_w, opt.lr_mult_b)
    # define optimizer
    optimizer = optim.SGD(finetune_params, 
                          opt.lr, 
                          momentum=opt.momentum, 
                          weight_decay=opt.weight_decay)
    # define learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, 
                                          step_size=opt.lr_decay_in_epoch,
                                          gamma=opt.gamma)
    if labels is not None:
        rid2name, id2rid = labels
    
    # record forward and backward times 
    train_batch_num = len(train_set)
    total_batch_iter = 0
    logging.info("####################Train Model###################")
    for epoch in range(opt.sum_epoch):
        epoch_start_t = time.time()
        epoch_batch_iter = 0
        logging.info('Begin of epoch %d' %(epoch))
        for i, data in enumerate(train_set):
            iter_start_t = time.time()
            # train 
            inputs, targets = data
            output, loss, loss_list = forward_batch(model, criterion, inputs, targets, opt, "Train")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
           
            webvis.reset()
            epoch_batch_iter += 1
            total_batch_iter += 1

            # display train loss and accuracy
            if total_batch_iter % opt.display_train_freq == 0:
                # accuracy
                batch_accuracy = calc_accuracy(output, targets, opt.score_thres, opt.top_k) 
                util.print_loss(loss_list, "Train", epoch, total_batch_iter)
                util.print_accuracy(batch_accuracy, "Train", epoch, total_batch_iter)
                if opt.display_id > 0:
                    x_axis = epoch + float(epoch_batch_iter)/train_batch_num
                    # TODO support accuracy visualization of multiple top_k
                    plot_accuracy = [batch_accuracy[i][opt.top_k[0]] for i in range(len(batch_accuracy)) ]
                    accuracy_list = [item["ratio"] for item in plot_accuracy]
                    webvis.plot_points(x_axis, loss_list, "Loss", "Train")
                    webvis.plot_points(x_axis, accuracy_list, "Accuracy", "Train")
            
            # display train data 
            if total_batch_iter % opt.display_data_freq == 0:
                image_list = list()
                show_image_num = int(np.ceil(opt.display_image_ratio * inputs.size()[0]))
                for index in range(show_image_num): 
                    input_im = util.tensor2im(inputs[index], opt.mean, opt.std)
                    class_label = "Image_" + str(index) 
                    if labels is not None:
                        target_ids = [targets[i][index] for i in range(opt.class_num)]
                        rids = [id2rid[j][k] for j,k in enumerate(target_ids)]
                        class_label += "_"
                        class_label += "#".join([rid2name[j][k] for j,k in enumerate(rids)])
                    image_list.append((class_label, input_im))
                image_dict = OrderedDict(image_list)
                save_result = total_batch_iter % opt.update_html_freq
                webvis.plot_images(image_dict, opt.display_id + 2*opt.class_num, epoch, save_result)
            
            # validate and display validate loss and accuracy
            if len(val_set) > 0  and total_batch_iter % opt.display_validate_freq == 0:
                val_accuracy, val_loss = validate(model, criterion, val_set, opt)
                x_axis = epoch + float(epoch_batch_iter)/train_batch_num
                accuracy_list = [val_accuracy[i][opt.top_k[0]]["ratio"] for i in range(len(val_accuracy))]
                util.print_loss(val_loss, "Validate", epoch, total_batch_iter)
                util.print_accuracy(val_accuracy, "Validate", epoch, total_batch_iter)
                if opt.display_id > 0:
                    webvis.plot_points(x_axis, val_loss, "Loss", "Validate")
                    webvis.plot_points(x_axis, accuracy_list, "Accuracy", "Validate")

            # save snapshot 
            if total_batch_iter % opt.save_batch_iter_freq == 0:
                logging.info("saving the latest model (epoch %d, total_batch_iter %d)" %(epoch, total_batch_iter))
                save_model(model, opt, epoch)
                # TODO snapshot loss and accuracy
            
        logging.info('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.sum_epoch, time.time() - epoch_start_t))
        
        if epoch % opt.save_epoch_freq == 0:
            logging.info('saving the model at the end of epoch %d, iters %d' %(epoch+1, total_batch_iter))
            save_model(model, opt, epoch+1) 

        # adjust learning rate 
        scheduler.step()
        lr = optimizer.param_groups[0]['lr'] 
        logging.info('learning rate = %.7f epoch = %d' %(lr,epoch)) 
    logging.info("--------Optimization Done--------")
Example #37
            break

        data["dp_target"] = data["dp_target"].permute(1, 0, 2, 3, 4)
        data["grid"] = data["grid"].permute(1, 0, 2, 3, 4)
        data["grid_source"] = data["grid_source"].permute(1, 0, 2, 3, 4)

        generated_video = []
        real_video = []

        generated = model.inference(data['dp_target'][0], data['source_frame'],
                                    data['source_frame'],
                                    data['grid_source'][0],
                                    data['grid_source'][0])

        stacked_images = np.hstack(
            [util.tensor2im(generated.data[i]) for i in range(0, 5)])
        stacked_images_source = np.hstack(
            [util.tensor2im(data['source_frame'][i]) for i in range(0, 5)])
        generated_video.append(stacked_images)
        visuals = OrderedDict([('source_images', stacked_images_source),
                               ('synthesized_image', stacked_images)])
        img_path = str(0)
        print('process image... %s' % img_path)
        visualizer.save_images(webpage, visuals, img_path)
        visualizer.display_current_results(visuals, 100, 12345)

        for i in range(1, data["dp_target"].shape[0]):

            generated = model.inference(data['dp_target'][i],
                                        data['source_frame'], generated,
                                        data['grid_source'][i],
Example #38
        data["B"] = torch.cat(B, 1)
        data["hint_B"] = torch.cat(HintB, 1)
        data["mask_B"] = torch.cat(MaskB, 1)

        # with no points
        for (pp, sample_p) in enumerate(sample_ps):
            img_path = [('%08d_%.3f' % (i, sample_p)).replace('.', 'p')]
            # data = util.get_colorization_data(data_raw[0], opt, ab_thresh=0., p=sample_p)

            model.set_input(data)
            model.test(True)  # True means that losses will be computed
            visuals = util.get_subset_dict(model.get_current_visuals(),
                                           to_visualize)
            vid = []
            for frames in range(0, visuals['fake_reg'].shape[1], 3):
                gray = util.tensor2im(visuals['gray'][:, frames:frames + 3])
                real = util.tensor2im(visuals['real'][:, frames:frames + 3])
                generated = util.tensor2im(
                    visuals['fake_reg'][:, frames:frames + 3])
                vid.append(np.vstack((gray, real, generated)))
                psnrs[i, pp] += util.calculate_psnr_np(
                    real, generated) / (visuals['fake_reg'].shape[1] / 3)
                #calculate_smoothness_here
            vid = np.array(vid)
            arr2gif(vid, "vid_{}".format(i))
            entrs[i, pp] = model.get_current_losses()['G_entr']

            # save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
            # print("")
        if i % 5 == 0:
            print('processing (%04d)-th image... %s' % (i, img_path))
Example #39
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    if opt.data_type == 16:
        data['label'] = data['label'].half()
        data['inst']  = data['inst'].half()
    elif opt.data_type == 8:
        data['label'] = data['label'].uint8()
        data['inst']  = data['inst'].uint8()
    if opt.export_onnx:
        print ("Exporting to ONNX: ", opt.export_onnx)
        assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
        torch.onnx.export(model, [data['label'], data['inst']],
                          opt.export_onnx, verbose=True)
        exit(0)
    minibatch = 1 
    if opt.engine:
        generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
    elif opt.onnx:
        generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
    else:
        generated = model.inference(data['label'], data['inst'])
        
    visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                           ('synthesized_image', util.tensor2im(generated.data[0]))])
    img_path = data['path']
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()
Example #40
def train(opt):
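    # Full GAN training driver: optionally resumes from the saved iteration counter,
    # alternates discriminator and generator updates, periodically logs errors and
    # visuals, saves 'latest'/per-epoch checkpoints, and decays the learning rate
    # at the epochs listed in opt.decay_epochs.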
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    if opt.continue_train:
        if opt.which_epoch == 'latest':
            try:
                start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                     delimiter=',',
                                                     dtype=int)
            except:
                start_epoch, epoch_iter = 1, 0
        else:
            start_epoch, epoch_iter = int(opt.which_epoch), 0

        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        for update_point in opt.decay_epochs:
            if start_epoch < update_point:
                break

            opt.lr *= opt.decay_gamma
    else:
        start_epoch, epoch_iter = 0, 0

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    total_steps = (start_epoch) * dataset_size + epoch_iter

    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    bSize = opt.batchSize

    # in case there's no display, sample one image from each class to test after every epoch
    if opt.display_id == 0:
        dataset.dataset.set_sample_mode(True)
        dataset.num_workers = 1
        for i, data in enumerate(dataset):
            if i * opt.batchSize >= opt.numClasses:
                break
            if i == 0:
                sample_data = data
            else:
                for key, value in data.items():
                    if torch.is_tensor(data[key]):
                        sample_data[key] = torch.cat(
                            (sample_data[key], data[key]), 0)
                    else:
                        sample_data[key] = sample_data[key] + data[key]
        dataset.num_workers = opt.nThreads
        dataset.dataset.set_sample_mode(False)

    for epoch in range(start_epoch, opt.epochs):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = 0
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = (total_steps % opt.display_freq
                         == display_delta) and (opt.display_id > 0)

            ############## Network Pass ########################
            model.set_inputs(data)
            disc_losses = model.update_D()
            gen_losses, gen_in, gen_out, rec_out, cyc_out = model.update_G(
                infer=save_fake)
            loss_dict = dict(gen_losses, **disc_losses)
            ##################################################

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.item()
                    if not (isinstance(v, float) or isinstance(v, int)) else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch + 1, epoch_iter, errors,
                                                t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, errors)

            ### display output images
            if save_fake and opt.display_id > 0:
                class_a_suffix = ' class {}'.format(data['A_class'][0])
                class_b_suffix = ' class {}'.format(data['B_class'][0])
                classes = None

                visuals = OrderedDict()
                visuals_A = OrderedDict([('real image' + class_a_suffix,
                                          util.tensor2im(gen_in.data[0]))])
                visuals_B = OrderedDict([('real image' + class_b_suffix,
                                          util.tensor2im(gen_in.data[bSize]))])

                A_out_vis = OrderedDict([('synthesized image' + class_b_suffix,
                                          util.tensor2im(gen_out.data[0]))])
                B_out_vis = OrderedDict([('synthesized image' + class_a_suffix,
                                          util.tensor2im(gen_out.data[bSize]))
                                         ])
                if opt.lambda_rec > 0:
                    A_out_vis.update([('reconstructed image' + class_a_suffix,
                                       util.tensor2im(rec_out.data[0]))])
                    B_out_vis.update([('reconstructed image' + class_b_suffix,
                                       util.tensor2im(rec_out.data[bSize]))])
                if opt.lambda_cyc > 0:
                    A_out_vis.update([('cycled image' + class_a_suffix,
                                       util.tensor2im(cyc_out.data[0]))])
                    B_out_vis.update([('cycled image' + class_b_suffix,
                                       util.tensor2im(cyc_out.data[bSize]))])

                visuals_A.update(A_out_vis)
                visuals_B.update(B_out_vis)
                visuals.update(visuals_A)
                visuals.update(visuals_B)

                ncols = len(visuals_A)
                visualizer.display_current_results(visuals, epoch, classes,
                                                   ncols)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch + 1, total_steps))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')
                if opt.display_id == 0:
                    model.eval()
                    visuals = model.inference(sample_data)
                    visualizer.save_matrix_image(visuals, 'latest')
                    model.train()

        # end of epoch
        iter_end_time = time.time()
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch + 1, opt.epochs, time.time() - epoch_start_time))

        ### save model for this epoch
        if (epoch + 1) % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch + 1, total_steps))
            model.save('latest')
            model.save(epoch + 1)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')
            if opt.display_id == 0:
                model.eval()
                visuals = model.inference(sample_data)
                visualizer.save_matrix_image(visuals, epoch + 1)
                model.train()

        ### multiply learning rate by opt.decay_gamma at the scheduled decay epochs
        if (epoch + 1) in opt.decay_epochs:
            model.update_learning_rate()
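The loop above persists (epoch, epoch_iter) to iter.txt with np.savetxt and, on resume, replays every learning-rate decay point that the saved epoch has already passed. A minimal, self-contained sketch of that resume pattern (the function names and option handling here are illustrative assumptions, not this repo's exact API):

import numpy as np

def load_resume_state(iter_path, base_lr, decay_epochs, decay_gamma):
    """Read (epoch, iter) saved by np.savetxt and replay LR decay up to that epoch."""
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
    except Exception:
        start_epoch, epoch_iter = 1, 0
    lr = base_lr
    for update_point in sorted(decay_epochs):
        if start_epoch < update_point:
            break
        lr *= decay_gamma  # this decay step already happened before the saved epoch
    return int(start_epoch), int(epoch_iter), lr

def save_resume_state(iter_path, epoch, epoch_iter):
    np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')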
Exemple #41
0
    if opt.export_onnx:
        print("Exporting to ONNX: ", opt.export_onnx)
        assert opt.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
        torch.onnx.export(model, [data['label'], data['inst']],
                          opt.export_onnx,
                          verbose=True)
        exit(0)
    minibatch = 1
    if opt.engine:
        generated = run_trt_engine(opt.engine, minibatch,
                                   [data['label'], data['inst']])
    elif opt.onnx:
        generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                             [data['label'], data['inst']])
    else:
        generated = model.inference(data['A'], data['B'], data['B2'])

    visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)),
                           ('real_image', util.tensor2im(data['A2'][0])),
                           ('synthesized_image',
                            util.tensor2im(generated.data[0])),
                           ('B', util.tensor2label(data['B'][0], 0)),
                           ('B2', util.tensor2im(data['B2'][0]))])
    img_path = data['path']
    img_path[0] = str(i)
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()
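Exemple #41 optionally exports the generator with torch.onnx.export before falling back to a TensorRT engine, an ONNX runtime, or direct inference. A toy export sketch showing the call shape (the module and input size below are placeholders, not this project's generator):

import torch
import torch.nn as nn

class TinyGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                                 nn.Conv2d(8, 3, 3, padding=1), nn.Tanh())

    def forward(self, x):
        return self.net(x)

model = TinyGenerator().eval()
dummy = torch.randn(1, 3, 64, 64)  # example input used only for tracing
torch.onnx.export(model, (dummy,), "tiny_generator.onnx",
                  input_names=["input"], output_names=["output"], verbose=False)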
image_files = fileList(in_dir)
# import pdb; pdb.set_trace()  # leftover debugging breakpoint, disabled so the script can run unattended

# Create dataset and model
dataset = create_dataset(opt)
model = create_model(opt)
model.setup(opt)

model.eval()
for i, data in enumerate(dataset):
    # import pdb; pdb.set_trace()  # leftover debugging breakpoint, disabled
    model.set_input(data)
    model.test()
    faked = model.get_current_visuals()['fake']
    faked_img = tensor2im(faked)
    img_path = model.image_paths

    print("Writing {}".format(img_path))
    plt.imsave('test.png', faked_img)  # note: every iteration overwrites the same file
'''
model.set_input(data)
model.test()
visuals = model.get_current_visuals()
fake = visuals['fake']


'''
	def get_current_visuals(self):
		real_A = util.tensor2im(self.real_A.data)
		fake_B = util.tensor2im(self.fake_B.data)
		real_B = util.tensor2im(self.real_B.data)
		return OrderedDict([('Blurred_Train', real_A), ('Restored_Train', fake_B), ('Sharp_Train', real_B)])
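util.tensor2im is called throughout these examples to turn network outputs into displayable images. A rough re-implementation under the usual assumption that tensors are CHW floats normalized to [-1, 1] (the helper shipped with each repo may differ in details):

import numpy as np
import torch

def tensor2im(image_tensor, imtype=np.uint8):
    """Convert a CHW (or 1xCHW) tensor in [-1, 1] to an HWC uint8 numpy image."""
    t = image_tensor.detach().cpu().float()
    if t.dim() == 4:          # drop the batch dimension if present
        t = t[0]
    array = t.numpy()
    if array.shape[0] == 1:   # replicate single-channel maps to RGB
        array = np.tile(array, (3, 1, 1))
    array = (np.transpose(array, (1, 2, 0)) + 1) / 2.0 * 255.0
    return np.clip(array, 0, 255).astype(imtype)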
Exemple #44
0
        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)

        ### display output images
        if save_fake:
            label = torch.squeeze(data['label'])
            image = torch.squeeze(data['image'])
            label = torch.cat((label[0], label[1]), dim=2)
            image = torch.cat((image[0], image[1]), dim=2)
            visuals = OrderedDict([('input_label', util.tensor2label(label, opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated)),
                                   ('real_image', util.tensor2im(image))])
            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')            
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

        if epoch_iter >= dataset_size:
            break
       
    # end of epoch 
    iter_end_time = time.time()
    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
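The print_delta/save_delta comparisons in these training loops keep the periodic print and save triggers aligned after a resume: instead of testing total_steps % freq == 0, the loop tests against the remainder the counter already had when training restarted. A small sketch of the idea (variable names mirror the examples but are not tied to any one repo):

def make_trigger(total_steps, freq):
    """Return a predicate that fires every `freq` steps, phase-aligned to the resumed counter."""
    delta = total_steps % freq
    return lambda step: step % freq == delta

# e.g. resuming at step 1030 with print_freq=100 fires at 1130, 1230, ...
should_print = make_trigger(1030, 100)
assert should_print(1130) and not should_print(1200)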
Exemple #45
0
        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}            
            t = (time.time() - iter_start_time) / opt.print_freq
            visualizer.print_current_errors(epoch, epoch_iter, errors, t, total)
            visualizer.plot_current_errors(errors, total_steps, total)
            call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])
        ### display output images
        
        race_str = data_loader.dataset.getLabelEncoder().inverse_transform(Variable(data['race']))[0]
        img_num = str(data['img_num'][0].item())

        if save_fake:
            if opt.deform:
                visuals = OrderedDict([('input_label', util.tensor2im(data['sketch'][0])),
                                       ('synthesized_image', util.tensor2im(result['fake_image'][0])),
                                       ('deformed_image', util.tensor2im(result['fake_image_deform'][0])),
                                       ('real_image', util.tensor2im(data['photo'][0]))])
            else:
                visuals = OrderedDict([('input_label', util.tensor2im(data['sketch'][0])),
                                       ('synthesized_image', util.tensor2im(result['fake_image'][0])),                         
                                       ('real_image', util.tensor2im(data['photo'][0]))])
    
            visualizer.display_current_results(visuals, epoch, total_steps, race_str, img_num)
        
        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')            
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')    
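race_str above comes from a scikit-learn-style LabelEncoder attached to the dataset; inverse_transform maps integer class indices back to their original string labels. A standalone illustration (the category names and dataset wiring are made up for the example):

from sklearn.preprocessing import LabelEncoder

encoder = LabelEncoder()
encoder.fit(["asian", "black", "caucasian", "indian"])   # hypothetical label set

idx = encoder.transform(["caucasian"])[0]      # string label -> integer class index
print(encoder.inverse_transform([idx])[0])     # integer class index -> 'caucasian'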
Exemple #46
0
dataset = CreateDataset(opt)

# test
model = create_model(opt)
if opt.verbose:
    print(model)

# test whole video sequence
# 20181009: do we use first frame as input?

data = dataset[0]
if opt.use_first_frame:
    prev_frame = data['image']
    start_from = 1
    from skimage.io import imsave
    imsave('results/ref.png', util.tensor2im(prev_frame))
    generated = [util.tensor2im(prev_frame)]
else:
    prev_frame = torch.zeros_like(data['image'])
    start_from = 0
    generated = []

from skimage.io import imsave
for i in tqdm(range(start_from, dataset.clip_length)):
    label = data['label'][i:i+1]
    #print(label.shape)
    inst = None if opt.no_instance else data['inst'][i:i+1]

    cur_frame = model.inference(label, inst, torch.unsqueeze(prev_frame, dim=0))
    prev_frame = cur_frame.data[0]
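Exemple #46 synthesizes a clip autoregressively: each generated frame becomes the "previous frame" conditioning for the next call to model.inference. A condensed sketch of that feedback loop (the inference signature and tensor shapes are assumptions read off the snippet):

import torch

def generate_clip(model, labels, insts, prev_frame):
    """labels/insts: [T, C, H, W] tensors; prev_frame: [C_img, H, W] seed frame (zeros if unknown)."""
    frames = []
    for t in range(labels.size(0)):
        inst_t = None if insts is None else insts[t:t + 1]
        cur = model.inference(labels[t:t + 1], inst_t, prev_frame.unsqueeze(0))  # assumed signature
        prev_frame = cur.detach()[0]   # feed the new frame back in as conditioning
        frames.append(prev_frame)
    return frames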
Exemple #47
0
        model.module.optimizer_D.step()

        #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 

        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)

        ### display output images
        if save_fake:
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated.data[0])),
                                   ('real_image', util.tensor2im(data['image'][0]))])
            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')            
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
       
    # end of epoch 
    iter_end_time = time.time()
    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

    ### save model for this epoch
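util.tensor2label, used for the 'input_label' panels in these examples, renders a segmentation map as a color image. A rough sketch under the common convention that the map is either one-hot or a single channel of class indices (the real helper's palette and edge handling may differ):

import numpy as np
import torch

def tensor2label(label_tensor, n_classes):
    """CHW label tensor (one-hot or single-channel indices) -> HWC uint8 color image."""
    t = label_tensor.detach().cpu().float()
    if t.dim() == 4:
        t = t[0]
    if t.size(0) > 1:                  # one-hot encoding: take the argmax channel
        t = t.argmax(dim=0, keepdim=True)
    ids = t[0].long().numpy()
    rng = np.random.RandomState(0)     # fixed palette, one color per class
    palette = rng.randint(0, 255, size=(max(n_classes, ids.max() + 1), 3), dtype=np.uint8)
    return palette[ids]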
Exemple #48
0
def train():
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    if opt.dataset_mode == 'pose':
        print('#training frames = %d' % dataset_size)
    else:
        print('#training videos = %d' % dataset_size)

    ### initialize models
    modelG, modelD, flowNet = create_model(opt)
    visualizer = Visualizer(opt)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    ### if continue training, recover previous states
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except Exception:  # iter file missing or unreadable
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        if start_epoch > opt.niter:
            modelG.module.update_learning_rate(start_epoch - 1)
            modelD.module.update_learning_rate(start_epoch - 1)
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                start_epoch > opt.niter_fix_global):
            modelG.module.update_fixed_params()
        if start_epoch > opt.niter_step:
            data_loader.dataset.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
            modelG.module.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
    else:
        start_epoch, epoch_iter = 1, 0

    ### set parameters
    n_gpus = opt.n_gpus_gen // opt.batchSize  # number of gpus used for generator for each batch
    tG, tD = opt.n_frames_G, opt.n_frames_D
    tDB = tD * opt.output_nc
    s_scales = opt.n_scales_spatial
    t_scales = opt.n_scales_temporal
    input_nc = 1 if opt.label_nc != 0 else opt.input_nc
    output_nc = opt.output_nc

    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    total_steps = total_steps // opt.print_freq * opt.print_freq

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0

            # n_frames_total = n_frames_load * n_loadings + tG - 1
            _, n_frames_total, height, width = data['B'].size()
            n_frames_total = n_frames_total // opt.output_nc
            n_frames_load = opt.max_frames_per_gpu * n_gpus  # number of total frames loaded into GPU at a time for each batch
            n_frames_load = min(n_frames_load, n_frames_total - tG + 1)
            t_len = n_frames_load + tG - 1  # number of loaded frames plus previous frames

            fake_B_last = None  # the last generated frame from previous training batch (which becomes input to the next batch)
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None  # all real/generated frames so far
            if opt.sparse_D:
                real_B_all = [None] * t_scales
                fake_B_all = [None] * t_scales
                flow_ref_all = [None] * t_scales
                conf_ref_all = [None] * t_scales
            real_B_skipped, fake_B_skipped = [None] * t_scales, [None] * t_scales      # temporally subsampled frames
            flow_ref_skipped, conf_ref_skipped = [None] * t_scales, [None] * t_scales  # temporally subsampled flows

            for i in range(0, n_frames_total - t_len + 1, n_frames_load):
                # 5D tensor: batchSize, # of frames, # of channels, height, width
                input_A = Variable(
                    data['A'][:, i * input_nc:(i + t_len) * input_nc,
                              ...]).view(-1, t_len, input_nc, height, width)
                input_B = Variable(
                    data['B'][:, i * output_nc:(i + t_len) * output_nc,
                              ...]).view(-1, t_len, output_nc, height, width)
                inst_A = Variable(data['inst'][:, i:i + t_len, ...]).view(
                    -1, t_len, 1, height,
                    width) if len(data['inst'].size()) > 2 else None

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_last)

                if i == 0:
                    fake_B_first = fake_B[0, 0]  # the first generated image in this sequence
                # the collections of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]

                ####### discriminator
                ### individual frame discriminator
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)  # reference flows and confidences
                fake_B_prev = real_B_prev[:, 0:1] if fake_B_last is None else fake_B_last[0][:, -1:]
                if fake_B.size()[1] > 1:
                    fake_B_prev = torch.cat(
                        [fake_B_prev, fake_B[:, :-1].detach()], dim=1)

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                loss_dict_T = []
                # get skipped frames for each temporal scale
                if t_scales > 0:
                    if opt.sparse_D:
                        real_B_all, real_B_skipped = get_skipped_frames_sparse(
                            real_B_all, real_B, t_scales, tD, n_frames_load, i)
                        fake_B_all, fake_B_skipped = get_skipped_frames_sparse(
                            fake_B_all, fake_B, t_scales, tD, n_frames_load, i)
                        flow_ref_all, flow_ref_skipped = get_skipped_frames_sparse(
                            flow_ref_all,
                            flow_ref,
                            t_scales,
                            tD,
                            n_frames_load,
                            i,
                            is_flow=True)
                        conf_ref_all, conf_ref_skipped = get_skipped_frames_sparse(
                            conf_ref_all,
                            conf_ref,
                            t_scales,
                            tD,
                            n_frames_load,
                            i,
                            is_flow=True)
                    else:
                        real_B_all, real_B_skipped = get_skipped_frames(
                            real_B_all, real_B, t_scales, tD)
                        fake_B_all, fake_B_skipped = get_skipped_frames(
                            fake_B_all, fake_B, t_scales, tD)
                        flow_ref_all, conf_ref_all, flow_ref_skipped, conf_ref_skipped = get_skipped_flows(
                            flowNet, flow_ref_all, conf_ref_all,
                            real_B_skipped, flow_ref, conf_ref, t_scales, tD)

                # run discriminator for each temporal scale
                for s in range(t_scales):
                    if real_B_skipped[s] is not None:
                        losses = modelD(s + 1, [
                            real_B_skipped[s], fake_B_skipped[s],
                            flow_ref_skipped[s], conf_ref_skipped[s]
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
                loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_Feat'] + loss_dict['G_VGG']
                loss_G += loss_dict['G_Warp'] + loss_dict['F_Flow'] + loss_dict['F_Warp'] + loss_dict['W']
                if opt.add_face_disc:
                    loss_G += loss_dict['G_f_GAN'] + loss_dict['G_f_GAN_Feat']
                    loss_D += (loss_dict['D_f_fake'] +
                               loss_dict['D_f_real']) * 0.5

                # collect temporal losses
                loss_D_T = []
                t_scales_act = min(t_scales, len(loss_dict_T))
                for s in range(t_scales_act):
                    loss_G += (loss_dict_T[s]['G_T_GAN'] + loss_dict_T[s]['G_T_GAN_Feat']
                               + loss_dict_T[s]['G_T_Warp'])
                    loss_D_T.append((loss_dict_T[s]['D_T_fake'] +
                                     loss_dict_T[s]['D_T_real']) * 0.5)

                ###################################### Backward Pass #################################
                optimizer_G = modelG.module.optimizer_G
                optimizer_D = modelD.module.optimizer_D
                # update generator weights
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                # update discriminator weights
                # individual frame discriminator
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                # temporal discriminator
                for s in range(t_scales_act):
                    optimizer_D_T = getattr(modelD.module,
                                            'optimizer_D_T' + str(s))
                    optimizer_D_T.zero_grad()
                    loss_D_T[s].backward()
                    optimizer_D_T.step()

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == 0:
                t = (time.time() - iter_start_time) / opt.print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                if opt.label_nc != 0:
                    input_image = util.tensor2label(real_A[0, -1],
                                                    opt.label_nc)
                elif opt.dataset_mode == 'pose':
                    input_image = util.tensor2im(real_A[0, -1, :3])
                    if real_A.size()[2] == 6:
                        input_image2 = util.tensor2im(real_A[0, -1, 3:])
                        input_image[input_image2 != 0] = input_image2[input_image2 != 0]
                else:
                    c = 3 if opt.input_nc == 3 else 1
                    input_image = util.tensor2im(real_A[0, -1, :c],
                                                 normalize=False)
                if opt.use_instance:
                    edges = util.tensor2im(real_A[0, -1, -1:, ...],
                                           normalize=False)
                    input_image += edges[:, :, np.newaxis]

                if opt.add_face_disc:
                    ys, ye, xs, xe = modelD.module.get_face_region(real_A[0, -1:])
                    if ys is not None:
                        # draw a white box around the face region used by the face discriminator
                        input_image[ys, xs:xe, :] = 255
                        input_image[ye, xs:xe, :] = 255
                        input_image[ys:ye, xs, :] = 255
                        input_image[ys:ye, xe, :] = 255

                visual_list = [
                    ('input_image', input_image),
                    ('fake_image', util.tensor2im(fake_B[0, -1])),
                    ('fake_first_image', util.tensor2im(fake_B_first)),
                    ('fake_raw_image', util.tensor2im(fake_B_raw[0, -1])),
                    ('real_image', util.tensor2im(real_B[0, -1])),
                    ('flow_ref', util.tensor2flow(flow_ref[0, -1])),
                    ('conf_ref',
                     util.tensor2im(conf_ref[0, -1], normalize=False))
                ]
                if flow is not None:
                    visual_list += [('flow', util.tensor2flow(flow[0, -1])),
                                    ('weight',
                                     util.tensor2im(weight[0, -1],
                                                    normalize=False))]
                visuals = OrderedDict(visual_list)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == 0:
                visualizer.vis_print(
                    'saving the latest model (epoch %d, total_steps %d)' %
                    (epoch, total_steps))
                modelG.module.save('latest')
                modelD.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            visualizer.vis_print(
                'saving the model at the end of epoch %d, iters %d' %
                (epoch, total_steps))
            modelG.module.save('latest')
            modelD.module.save('latest')
            modelG.module.save(epoch)
            modelD.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            modelG.module.update_learning_rate(epoch)
            modelD.module.update_learning_rate(epoch)

        ### gradually grow training sequence length
        if (epoch % opt.niter_step) == 0:
            data_loader.dataset.update_training_batch(epoch // opt.niter_step)
            modelG.module.update_training_batch(epoch // opt.niter_step)

        ### finetune all scales
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                epoch == opt.niter_fix_global):
            modelG.module.update_fixed_params()
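The epoch loop ends by calling update_learning_rate once epoch exceeds opt.niter. The usual pix2pixHD/vid2vid-style schedule holds the base rate for the first opt.niter epochs and then decays it linearly to zero over opt.niter_decay epochs; a small sketch of that rule (the exact formula and optimizer wiring inside this repo are assumptions):

def linearly_decayed_lr(base_lr, epoch, niter, niter_decay):
    """Constant for `niter` epochs, then linear decay to 0 over `niter_decay` epochs."""
    if epoch <= niter:
        return base_lr
    return base_lr * max(0.0, 1.0 - (epoch - niter) / float(niter_decay))

def update_optimizer_lr(optimizer, lr):
    # apply the new rate to every parameter group of a torch.optim optimizer
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr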