Example #1
def generate(config, writer, logger):
    config = config.opt

    config.distributed = False

    data_set = CreateDataLoader(config).load_data()
    model = create_model(config)
    visualizer = Visualizer(config)

    web_dir = os.path.join(config.results_dir, config.name,
                           '%s_%s' % (config.phase, config.which_epoch))

    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (config.name, config.phase, config.which_epoch))

    is_first = True

    average_tensor = utils.load_average_img(config)
    prev_generated = average_tensor.view(1, *average_tensor.shape)

    for data in tqdm(data_set):
        data['label'] = data['label'][:, :1]
        assert data['label'].shape[1] == 1

        if config.no_temporal_smoothing:
            generated = model.inference(data['label'].cuda(),
                                        data['inst'].cuda())
        else:
            generated = model.inference(data['label'].cuda(),
                                        data['inst'].cuda(),
                                        prev_generated.cuda())
            prev_generated = generated
            is_first = False

        visuals = OrderedDict([('input_label',
                                util.tensor2label(data['label'][0],
                                                  config.label_nc)),
                               ('synthesized_image',
                                util.tensor2im(generated.data[0]))])

        img_path = data['path']
        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()
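
The snippet above calls utils.load_average_img, which is not shown. A minimal sketch of what such a helper might look like, assuming the average image is stored as an RGB file whose location is derived from the options (both the path and the normalization here are assumptions):

import os
import numpy as np
import torch
from PIL import Image

def load_average_img(config):
    """Hypothetical helper: load the dataset's average image as a CHW float tensor in [-1, 1]."""
    path = os.path.join(config.checkpoints_dir, config.name, 'average.png')  # assumed location
    img = np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0
    img = img * 2.0 - 1.0  # assumed [-1, 1] network input range
    return torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW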
Example #2
def transfer(opt=test_opt):
    ''' Transfer source video to target video '''
    
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)
    print(f'# testing images = {len(data_loader)}')
    
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, f'{opt.phase}_{opt.which_epoch}')
    webpage = html.HTML(web_dir, f'Experiment = {opt.name}, Phase = {opt.phase}, Epoch = {opt.which_epoch}')

    model = create_model(opt)
    model = model.cuda()

    for data in tqdm(dataset):
        generated = model.inference(data['label'], data['inst'])
            
        visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                               ('synthesized_image', util.tensor2im(generated.data[0]))])
        img_path = data['path']
        visualizer.save_images(webpage, visuals, img_path)
    
    webpage.save()
    torch.cuda.empty_cache()
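
Both this example and Example #8 import a project-specific test_opt options object. Purely for orientation, a hedged stand-in showing the kind of fields these scripts expect; the names and values below are illustrative, not the project's real defaults:

from types import SimpleNamespace

# Hypothetical stand-in for src/config/test_opt.py; the real module defines many more fields.
test_opt = SimpleNamespace(
    name='experiment1',
    phase='test',
    which_epoch='latest',
    checkpoints_dir='./checkpoints',
    results_dir='./results',
    label_nc=35,           # illustrative number of label classes
    batchSize=1,
    nThreads=1,
    serial_batches=True,
    no_flip=True,
)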
Example #3
def gan_main():
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))
    # test
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        generated = model.inference(data['label'], data['inst'])
        visuals = OrderedDict([
            ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
            ('synthesized_image', util.tensor2im(generated.data[0]))
        ])
        img_path = data['path']
        print('process image... %s' % img_path)
        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()
Example #4
    def get_current_visuals(self):
        return OrderedDict([('input_label',
                             util.tensor2label(self.input_label,
                                               self.opt.label_nc)),
                            ('input_image', util.tensor2im(self.input_image)),
                            ('real_image', util.tensor2im(self.real_image)),
                            ('synthesized_image',
                             util.tensor2im(self.fake_image))])
Example #5
def generate_label_color(inputs):
    label_batch = []
    for i in range(len(inputs)):
        label_batch.append(util.tensor2label(inputs[i], NC))
    label_batch = np.array(label_batch)
    # map from [0, 1] to [-1, 1] (assumes util.tensor2label returned floats in [0, 1])
    label_batch = label_batch * 2 - 1
    input_label = torch.from_numpy(label_batch)

    return input_label
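
A quick usage sketch for generate_label_color, assuming NC label classes and that this project's util.tensor2label returns float images in [0, 1] (which the * 2 - 1 rescale above implies); the shapes are illustrative:

labels = torch.randint(0, NC, (4, 1, 256, 256)).float()  # a batch of single-channel label maps
colored = generate_label_color(labels)  # colorized labels in [-1, 1], one image per batch entry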
Example #6
    def get_current_visuals(self, getLabel=False):                              
        mask = self.mask
        if self.mask is not None:
            # first mask in the batch: CHW tensor -> HWC uint8 array for display
            mask = np.transpose(self.mask[0].cpu().float().numpy(), (1, 2, 0)).astype(np.uint8)

        dict_list = [('fake_image', self.fake_image), ('mask', mask)]

        if getLabel: # only output label map if needed to save bandwidth
            label = util.tensor2label(self.net_input.data[0], self.opt.label_nc)                    
            dict_list += [('label', label)]

        return OrderedDict(dict_list)
Example #7
def single_generation_from_update(save_path,
                                  fname,
                                  features,
                                  checkpoints_dir,
                                  classname,
                                  black=True):
    ''' Generate a decoded segmentation map from input features.
        Args:
            save_path (str): directory in which to save the generated masks
            fname (str): filename stem for the generated masks
            features (numpy array): input features to be decoded
            checkpoints_dir (str): directory to load the VAE weights from
            classname (str): label taxonomy defined by the dataset
            black (bool): True for regular generation; False for debugging,
                in which case the generated mask is not in cGAN input format
    '''
    vae_opt = initialize_option(classname)
    vae_opt.checkpoints_dir = checkpoints_dir

    vae_util.mkdirs(save_path)

    if vae_opt.share_decoder and vae_opt.share_encoder:
        if vae_opt.separate_clothing_unrelated:
            from models.separate_clothing_encoder_models import create_model as vae_create_model
        else:
            print('Only supports separating clothing and clothing-irrelevant')
            raise NotImplementedError
    else:
        print('Only supports sharing encoder and decoder among all parts')
        raise NotImplementedError

    model = vae_create_model(vae_opt)
    generated = model.generate_from_random(torch.Tensor(features))  # .cuda() if running on GPU

    if black:
        vae_util.save_image(
            vae_util.tensor2label_black(generated.data[0],
                                        vae_opt.output_nc,
                                        normalize=True),
            os.path.join(save_path, '%s.png' % (fname)))
    else:
        vae_util.save_image(
            vae_util.tensor2label(generated.data[0],
                                  vae_opt.output_nc,
                                  normalize=True),
            os.path.join(save_path, '%s.png' % (fname)))
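
A hedged usage sketch for single_generation_from_update; the latent size, checkpoint path, and class name below are placeholders, since they depend on the trained VAE:

import numpy as np

features = np.random.randn(1, 8)  # placeholder latent vector; the real size depends on the VAE
single_generation_from_update(save_path='./generated_masks',
                              fname='sample_000',
                              features=features,
                              checkpoints_dir='./checkpoints/vae',  # placeholder path
                              classname='clothing',                 # placeholder taxonomy name
                              black=True)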
Example #8
def test_transfer(source_dir, run_name, temporal_smoothing=False, live_run_name=None):
    import src.config.test_opt as opt

    opt.name = run_name
    opt.dataroot = source_dir
    opt.temporal_smoothing = temporal_smoothing
    if device == torch.device('cpu'):
        opt.gpu_ids = []
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    opt.checkpoints_dir = os.path.join(dir_name, '../../checkpoints')
    opt.results_dir = os.path.join(dir_name, '../../results')

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    #print(opt.load_pretrain)
    model = create_model(opt)

    if live_run_name is not None:
        opt.name = live_run_name
    visualizer = Visualizer(opt)

    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    generated = None
    for data in tqdm(dataset):
        if temporal_smoothing:
            if generated is None:
                # no frame generated yet: warm-start from an all-black previous frame
                previous_frame = torch.zeros((1, 3, opt.loadSize, opt.loadSize))
                generated = model.inference(data['label'], data['inst'], previous_frame)
            else:
                generated = model.inference(data['label'], data['inst'], generated)

        else:
            generated = model.inference(data['label'], data['inst'])

        visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                               ('synthesized_image', util.tensor2im(generated.data[0]))])
        img_path = data['path']
        visualizer.save_images(webpage, visuals, img_path)
    webpage.save()
    torch.cuda.empty_cache()
Example #9
    def get_current_visuals1(self):
        return OrderedDict([
            ('input_label',
             util.tensor2label(self.input_label, self.opt.label_nc)),
            ('init_predict_label',
             util.tensor2label(self.init_predict_label, self.opt.label_nc)),
            ('input_label1',
             util.tensor2label(self.input_label1, self.opt.label_nc)),
            ('predict_label',
             util.tensor2label(self.predict_label, self.opt.label_nc)),
            ('ori_predict_label',
             util.tensor2label(self.ori_predict_label, self.opt.label_nc)),
            ('target_label',
             util.tensor2label(self.target_label, self.opt.label_nc)),
        ])
Example #10
def visualize(data, generated, opt):
    if opt.model == 'pix2pixHD':

        visuals = OrderedDict([
            ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
            ('synthesized_image', util.tensor2im(generated.data[0])),
            ('real_image', util.tensor2im(data['image'][0]))
        ])
        visualizer.display_current_results(visuals, epoch, total_steps)

    elif opt.model == 'pix2pixHDts':

        syn = generated[0].data[0]
        inputs = torch.cat((data['label'], data['next_label']), dim=3)
        targets = torch.cat((data['image'], data['next_image']), dim=3)
        visuals = OrderedDict([('input_label',
                                util.tensor2im(inputs[0], normalize=False)),
                               ('synthesized_image', util.tensor2im(syn)),
                               ('real_image', util.tensor2im(targets[0]))])
        if opt.face:  # display the face generator on TensorBoard
            # face_coords gives pixel bounds; image tensors are CHW, so the crops below are [:, y, x]
            minx, miny, maxx, maxy = data['face_coords'][0]
            res_face = generated[2].data[0]
            syn_face = generated[1].data[0]
            preres = generated[3].data[0]
            visuals = OrderedDict([
                ('input_label', util.tensor2im(inputs[0], normalize=False)),
                ('synthesized_image', util.tensor2im(syn)),
                ('synthesized_face', util.tensor2im(syn_face)),
                ('residual', util.tensor2im(res_face)),
                ('real_face',
                 util.tensor2im(data['image'][0][:, miny:maxy, minx:maxx])),
                # ('pre_residual', util.tensor2im(preres)),
                # ('pre_residual_face', util.tensor2im(preres[:, miny:maxy, minx:maxx])),
                ('input_face',
                 util.tensor2im(data['label'][0][:, miny:maxy, minx:maxx],
                                normalize=False)),
                ('real_image', util.tensor2im(targets[0]))
            ])
        visualizer.display_current_results(visuals, epoch, total_steps)
Example #11
        loss_D.backward()
        model.module.optimizer_D.step()

        #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 

        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)

        ### display output images
        if save_fake:
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated.data[0])),
                                   ('real_image', util.tensor2im(data['image'][0]))])
            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')            
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
       
    # end of epoch 
    iter_end_time = time.time()
    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
Example #12
def train():
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    if opt.dataset_mode == 'pose':
        print('#training frames = %d' % dataset_size)
    else:
        print('#training videos = %d' % dataset_size)

    ### initialize models
    modelG, modelD, flowNet = create_model(opt)
    visualizer = Visualizer(opt)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    ### if continue training, recover previous states
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except (OSError, ValueError):  # missing or malformed iter file
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        if start_epoch > opt.niter:
            modelG.module.update_learning_rate(start_epoch - 1)
            modelD.module.update_learning_rate(start_epoch - 1)
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                start_epoch > opt.niter_fix_global):
            modelG.module.update_fixed_params()
        if start_epoch > opt.niter_step:
            data_loader.dataset.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
            modelG.module.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
    else:
        start_epoch, epoch_iter = 1, 0

    ### set parameters
    n_gpus = opt.n_gpus_gen // opt.batchSize  # number of gpus used for generator for each batch
    tG, tD = opt.n_frames_G, opt.n_frames_D
    tDB = tD * opt.output_nc
    s_scales = opt.n_scales_spatial
    t_scales = opt.n_scales_temporal
    input_nc = 1 if opt.label_nc != 0 else opt.input_nc
    output_nc = opt.output_nc

    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    total_steps = total_steps // opt.print_freq * opt.print_freq

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0

            # n_frames_total = n_frames_load * n_loadings + tG - 1
            _, n_frames_total, height, width = data['B'].size()
            n_frames_total = n_frames_total // opt.output_nc
            # number of frames loaded into the GPU at a time for each batch
            n_frames_load = opt.max_frames_per_gpu * n_gpus
            n_frames_load = min(n_frames_load, n_frames_total - tG + 1)
            t_len = n_frames_load + tG - 1  # number of loaded frames plus previous frames

            # the last generated frame from the previous training batch (input to the next batch)
            fake_B_last = None
            # all real/generated frames so far
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None
            # temporally subsampled frames and flows
            real_B_skipped, fake_B_skipped = [None] * t_scales, [None] * t_scales
            flow_ref_skipped, conf_ref_skipped = [None] * t_scales, [None] * t_scales

            for i in range(0, n_frames_total - t_len + 1, n_frames_load):
                # 5D tensors: batchSize, # of frames, # of channels, height, width
                input_A = Variable(data['A'][:, i * input_nc:(i + t_len) * input_nc, ...]
                                   ).view(-1, t_len, input_nc, height, width)
                input_B = Variable(data['B'][:, i * output_nc:(i + t_len) * output_nc, ...]
                                   ).view(-1, t_len, output_nc, height, width)
                inst_A = (Variable(data['inst'][:, i:i + t_len, ...]
                                   ).view(-1, t_len, 1, height, width)
                          if len(data['inst'].size()) > 2 else None)

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_last)

                if i == 0:
                    # the first generated image in this sequence
                    fake_B_first = fake_B[0, 0]
                # the collections of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]

                ####### discriminator
                ### individual frame discriminator
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)  # reference flows and confidences
                fake_B_prev = (real_B_prev[:, 0:1] if fake_B_last is None
                               else fake_B_last[0][:, -1:])
                if fake_B.size()[1] > 1:
                    fake_B_prev = torch.cat([fake_B_prev, fake_B[:, :-1].detach()], dim=1)

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                loss_dict_T = []
                # get skipped frames for each temporal scale
                if t_scales > 0:
                    real_B_all, real_B_skipped = get_skipped_frames(
                        real_B_all, real_B, t_scales, tD)
                    fake_B_all, fake_B_skipped = get_skipped_frames(
                        fake_B_all, fake_B, t_scales, tD)
                    flow_ref_all, conf_ref_all, flow_ref_skipped, conf_ref_skipped = get_skipped_flows(
                        flowNet, flow_ref_all, conf_ref_all, real_B_skipped,
                        flow_ref, conf_ref, t_scales, tD)

                # run discriminator for each temporal scale
                for s in range(t_scales):
                    if real_B_skipped[s] is not None and real_B_skipped[s].size()[1] == tD:
                        losses = modelD(s + 1, [real_B_skipped[s], fake_B_skipped[s],
                                                flow_ref_skipped[s], conf_ref_skipped[s]])
                        losses = [torch.mean(x) if not isinstance(x, int) else x
                                  for x in losses]
                        loss_dict_T.append(dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
                loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_Feat'] + loss_dict['G_VGG']
                loss_G += (loss_dict['G_Warp'] + loss_dict['F_Flow'] +
                           loss_dict['F_Warp'] + loss_dict['W'])
                if opt.add_face_disc:
                    loss_G += loss_dict['G_f_GAN'] + loss_dict['G_f_GAN_Feat']
                    loss_D += (loss_dict['D_f_fake'] + loss_dict['D_f_real']) * 0.5

                # collect temporal losses
                loss_D_T = []
                t_scales_act = min(t_scales, len(loss_dict_T))
                for s in range(t_scales_act):
                    loss_G += (loss_dict_T[s]['G_T_GAN'] + loss_dict_T[s]['G_T_GAN_Feat'] +
                               loss_dict_T[s]['G_T_Warp'])
                    loss_D_T.append((loss_dict_T[s]['D_T_fake'] +
                                     loss_dict_T[s]['D_T_real']) * 0.5)

                ###################################### Backward Pass #################################
                optimizer_G = modelG.module.optimizer_G
                optimizer_D = modelD.module.optimizer_D
                # update generator weights
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                # update discriminator weights
                # individual frame discriminator
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                # temporal discriminator
                for s in range(t_scales_act):
                    optimizer_D_T = getattr(modelD.module,
                                            'optimizer_D_T' + str(s))
                    optimizer_D_T.zero_grad()
                    loss_D_T[s].backward()
                    optimizer_D_T.step()

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == 0:
                t = (time.time() - iter_start_time) / opt.print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                if opt.label_nc != 0:
                    input_image = util.tensor2label(real_A[0, -1],
                                                    opt.label_nc)
                elif opt.dataset_mode == 'pose':
                    input_image = util.tensor2im(real_A[0, -1, :3],
                                                 normalize=False)
                    if real_A.size()[2] == 6:
                        input_image2 = util.tensor2im(real_A[0, -1, 3:], normalize=False)
                        input_image[input_image2 != 0] = input_image2[input_image2 != 0]
                else:
                    c = 3 if opt.input_nc == 3 else 1
                    input_image = util.tensor2im(real_A[0, -1, :c],
                                                 normalize=False)
                if opt.use_instance:
                    edges = util.tensor2im(real_A[0, -1, -1:, ...],
                                           normalize=False)
                    input_image += edges[:, :, np.newaxis]

                if opt.add_face_disc:
                    ys, ye, xs, xe = modelD.module.get_face_region(real_A[0, -1:])
                    if ys is not None:
                        # draw a white box around the detected face region
                        input_image[ys, xs:xe, :] = input_image[ye, xs:xe, :] = 255
                        input_image[ys:ye, xs, :] = input_image[ys:ye, xe, :] = 255

                visual_list = [
                    ('input_image', input_image),
                    ('fake_image', util.tensor2im(fake_B[0, -1])),
                    ('fake_first_image', util.tensor2im(fake_B_first)),
                    ('fake_raw_image', util.tensor2im(fake_B_raw[0, -1])),
                    ('real_image', util.tensor2im(real_B[0, -1])),
                    ('flow_ref', util.tensor2flow(flow_ref[0, -1])),
                    ('conf_ref',
                     util.tensor2im(conf_ref[0, -1], normalize=False))
                ]
                if flow is not None:
                    visual_list += [('flow', util.tensor2flow(flow[0, -1])),
                                    ('weight',
                                     util.tensor2im(weight[0, -1],
                                                    normalize=False))]
                visuals = OrderedDict(visual_list)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == 0:
                visualizer.vis_print(
                    'saving the latest model (epoch %d, total_steps %d)' %
                    (epoch, total_steps))
                modelG.module.save('latest')
                modelD.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            visualizer.vis_print(
                'saving the model at the end of epoch %d, iters %d' %
                (epoch, total_steps))
            modelG.module.save('latest')
            modelD.module.save('latest')
            modelG.module.save(epoch)
            modelD.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            modelG.module.update_learning_rate(epoch)
            modelD.module.update_learning_rate(epoch)

        ### gradually grow training sequence length
        if (epoch % opt.niter_step) == 0:
            data_loader.dataset.update_training_batch(epoch // opt.niter_step)
            modelG.module.update_training_batch(epoch // opt.niter_step)

        ### finetune all scales
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                epoch == opt.niter_fix_global):
            modelG.module.update_fixed_params()
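
Several of these training scripts call an lcm helper (e.g. opt.print_freq = lcm(opt.print_freq, opt.batchSize)) without showing it. A minimal implementation, assuming Python 3:

from math import gcd

def lcm(a, b):
    """Least common multiple; used to align print_freq with batchSize."""
    return a * b // gcd(a, b) if a and b else 0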
Example #13
# test
loss = torch.nn.L1Loss()
losses = []
paths = []
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break

    generated = model.fake_inference(data['image'],
                                     data['label'],
                                     data['inst'],
                                     pose=data['pose'],
                                     normal=data['normal'],
                                     depth=data['depth'])

    visuals = [('input_label', util.tensor2label(data['label'][0],
                                                 opt.label_nc)),
               ('input_inst', util.tensor2label(data['inst'][0],
                                                opt.label_nc))]
    if opt.feat_pose:
        visuals += [('input_pose',
                     util.tensor2label(data['pose'][0],
                                       opt.feat_pose_num_bins))]
    if opt.feat_normal:
        visuals += [('input_normal', util.tensor2im(data['normal'][0]))]
    if opt.feat_depth:
        visuals += [('input_depth', util.tensor2im(data['depth'][0]))]
    visuals += [
        ('real_image', util.tensor2im(data['image'][0])),
        ('%s_%s_%s' % (opt.phase, opt.which_epoch, opt.experiment_name),
         util.tensor2im(generated.data[0])),
    ]
Example #14
save_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' %
                        (opt.phase, opt.which_epoch))
print('Doing %d frames' % len(dataset))
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    if data['change_seq']:
        model.fake_B_prev = None

    _, _, height, width = data['A'].size()
    A = Variable(data['A']).view(1, -1, input_nc, height, width)
    B = (Variable(data['B']).view(1, -1, opt.output_nc, height, width)
         if len(data['B'].size()) > 2 else None)
    inst = (Variable(data['inst']).view(1, -1, 1, height, width)
            if len(data['inst'].size()) > 2 else None)
    # inference returns a pair: generated[0] is the synthesized frame, generated[1] the processed input
    generated = model.inference(A, B, inst)

    if opt.label_nc != 0:
        real_A = util.tensor2label(generated[1], opt.label_nc)
    else:
        c = 3 if opt.input_nc == 3 else 1
        real_A = util.tensor2im(generated[1][:c], normalize=False)

    visual_list = [('real_A', real_A),
                   ('fake_B', util.tensor2im(generated[0].data[0]))]
    visuals = OrderedDict(visual_list)
    img_path = data['A_path']
    print('process image... %s' % img_path)
    visualizer.save_images(save_dir, visuals, img_path)
Example #15
        # Get ground-truth output
        with torch.no_grad():
            generated_noattack = model.inference(data['label'], data['inst'],
                                                 data['image'])
        # Attack
        adv_image, perturb = model.attack(data['label'],
                                          data['inst'],
                                          data['image'],
                                          target=generated_noattack)
        # Get output from adversarial sample
        with torch.no_grad():
            generated, adv_img = model.inference_attack(
                data['label'], data['inst'], data['image'], perturb)

    visuals = OrderedDict([
        ('original_label', util.tensor2label(data['label'][0], opt.label_nc)),
        ('input_label', util.tensor2label(adv_img.data[0], opt.label_nc)),
        ('attacked_image', util.tensor2im(generated.data[0])),
        ('noattack', util.tensor2im(generated_noattack.data[0]))
    ])
    img_path = data['path']
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

    # Compute metrics: L1/L2 distances, the number of changed elements (0-"norm"),
    # and the minimum absolute deviation (-inf norm)
    l1_error += F.l1_loss(generated, generated_noattack)
    l2_error += F.mse_loss(generated, generated_noattack)
    l0_error += (generated - generated_noattack).norm(0)
    min_dist += (generated - generated_noattack).norm(float('-inf'))
    if F.mse_loss(generated, generated_noattack) > 0.05:
        n_dist += 1
Example #16
                k: v.data.item() if not isinstance(v, int) else v
                for k, v in loss_dict.items()
            }
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch,
                                            train_epoch_iter,
                                            errors,
                                            t,
                                            mode='train')
            visualizer.plot_current_errors(errors, total_steps, mode='train')

        ### display output images
        if save_train:

            visuals = OrderedDict([
                ('input_label', util.tensor2label(train_data['image'][0])),
                ('synthesized_image',
                 util.tensor2im(generated.data[0], train_data['image'][0])),
                ('real_image',
                 util.tensor2imreal(train_data['label'][0],
                                    train_data['image'][0]))
            ])

            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.module.save('latest')
            np.savetxt(iter_path, (epoch, train_epoch_iter),
Example #17
    reconstructed_comb_label = reconstructed['comb_recon_label']
    reconstructed_obj_label = reconstructed['obj_recon_label']
    #
    generated = model.generate({
        'label_map': Variable(data['label'], volatile=True),
        'mask_obj_in': Variable(data['mask_object_in'], volatile=True),
        'mask_ctx_in': Variable(data['mask_context_in'], volatile=True),
        'mask_obj_out': Variable(data['mask_object_out'], volatile=True),
        'mask_out': Variable(data['mask_out'], volatile=True),
        'mask_obj_inst': Variable(data['mask_object_inst'], volatile=True),
        'cls': Variable(data['cls'], volatile=True),
        'mask_in': Variable(data['mask_in'], volatile=True)
        })
    generated_comb_label = generated['comb_pred_label']
    generated_obj_label = generated['obj_pred_label']

    visuals = OrderedDict([
        ('image', util.tensor2im(data['image'][0])),
        ('gt_label', util.tensor2label(data['label'][0], opt.label_nc)),
        ('input_context', util.tensor2label(data['mask_context_in'][0], opt.label_nc)),
        ('mask_in', util.tensor2im(data['mask_in'][0])),
        ('reconstructed_comb_label', util.tensor2label(reconstructed_comb_label.data[0], opt.label_nc)),
        # ('reconstructed_obj_label', util.tensor2im(reconstructed_obj_label.data[0])),
        ('generated_comb_label', util.tensor2label(generated_comb_label.data[0], opt.label_nc)),
        # ('generated_obj_label', util.tensor2im(generated_obj_label.data[0])),
    ])
    label_path = data['label_path']
    print('process image... %s' % label_path)
    visualizer.save_images(webpage, visuals, label_path)

webpage.save()
Example #18
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name,
                       '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(
    web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
    (opt.name, opt.phase, opt.which_epoch))
# test
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    if opt.model == 'Bpgan_GAN':
        generated, latent_vector = model.inference(data['label'])
    elif opt.model == 'Bpgan_GAN_Q':
        generated, latent_vector = model.inference(data['label'],
                                                   Q_type='Hard')
    visuals = OrderedDict([
        ('input_label', util.tensor2label(data['label'][0], 0)),
        ('synthesized_image', util.tensor2im(generated.data[0])),
        ('real_image', util.tensor2im(data['image'][0]))
    ])
    print(latent_vector.shape)
    img_path = data['path']
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()
Example #19
            for loss_name in loss_names:
                loss_mean_temp[loss_name] = loss_mean_temp[loss_name].item() / loss_count

            # errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
            errors = {k: v for k, v in loss_mean_temp.items()}
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)
            for loss_name in loss_names:
                loss_mean_temp[loss_name] = 0
            loss_count = 0

        ### display output images
        if save_fake:
            if opt.debug_mask_part:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('input_ori_label', util.tensor2label(data['ori_label'][0], opt.label_nc)),
                                       ('transfer_label', util.tensor2label(transfer_label.data[0], opt.label_nc)),
                                       ('transfer_image', util.tensor2im(transfer_image.data[0])),
                                       ('reconstruct_image', util.tensor2im(reconstruct.data[0])),
                                       ('real_image', util.tensor2im(data['bg_image'][0]))
                                       # ('parsing_label', util.tensor2label(label_out.data[0], opt.label_nc)),
                                       # ('real_parsing_label', util.tensor2label(real_parsing_label.data[0], opt.label_nc)),
                                       # ('reconstruct_left_eye', util.tensor2im(left_eye_reconstruct.data[0])),
                                       # ('reconstruct_right_eye', util.tensor2im(right_eye_reconstruct.data[0])),
                                       # ('reconstruct_skin', util.tensor2im(skin_reconstruct.data[0])),
                                       # ('reconstruct_hair', util.tensor2im(hair_reconstruct.data[0])),
                                       # ('reconstruct_mouth', util.tensor2im(mouth_reconstruct.data[0])),
                                       # ('mask_lefteye', util.tensor2im(left_eye_real.data[0]))
                                       ])
            else:
Example #20
                          verbose=True)
        exit(0)
    minibatch = 1
    if opt.engine:
        generated = run_trt_engine(opt.engine, minibatch,
                                   [data['label'], data['inst']])
    elif opt.onnx:
        generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                             [data['label'], data['inst']])
    else:
        generated = model.inference(data['label'], data['inst'], data['image'])

    img_path = data['path']
    print('process image %d... %s' % (i, img_path))

    orig_im = util.tensor2label(data['label'][0], opt.label_nc)
    mask = util.tensor2label(data['mask'][0], opt.label_nc)
    synthesized_im = util.tensor2im(generated.data[0])
    masked_im = apply_mask(orig_im, synthesized_im, mask)
    enhanced = enhance_brightening(orig_im, masked_im)
    enhanced1 = enhance_brightening(orig_im, masked_im, factor=1.75)
    enhanced2 = enhance_brightening(orig_im, masked_im, factor=2)

    recolored = recolor_im(orig_im, mask, dst_im)
    recolored_from_synth = recolor_im(synthesized_im, mask, dst_im)

    # visuals = OrderedDict([('input_label', orig_im),
    #                        ('synthesized_image', synthesized_im),
    #                        ('masked_image', masked_im),
    #                        ('enhanced', enhanced)
    #                        ])
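
apply_mask, enhance_brightening, and recolor_im are project helpers that this excerpt does not define. A minimal sketch of what apply_mask might do, assuming HWC uint8 images and a mask that is nonzero where the synthesized pixels should be used (purely illustrative):

import numpy as np

def apply_mask(orig_im, synthesized_im, mask):
    """Hypothetical composite: synthesized pixels where mask is set, original pixels elsewhere."""
    m = (mask > 0).astype(orig_im.dtype)
    if m.ndim == 2:
        m = m[:, :, np.newaxis]  # broadcast a 2D mask across the RGB channels
    return orig_im * (1 - m) + synthesized_im * m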
Example #21
    elif opt.data_type == 8:
        data['label'] = data['label'].uint8()
        data['inst'] = data['inst'].uint8()
    if opt.export_onnx:
        print("Exporting to ONNX: ", opt.export_onnx)
        assert opt.export_onnx.endswith(
            "onnx"), "Export model file should end with .onnx"
        torch.onnx.export(model, [data['label'], data['inst']],
                          opt.export_onnx,
                          verbose=True)
        exit(0)
    minibatch = 1
    if opt.engine:
        generated = run_trt_engine(opt.engine, minibatch,
                                   [data['label'], data['inst']])
    elif opt.onnx:
        generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                             [data['label'], data['inst']])
    else:
        generated = model.inference(data['label'], data['inst'], data['image'])

    visuals = OrderedDict([
        ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
        ('synthesized_image', util.tensor2im(generated.data[0]))
    ])
    img_path = data['path']
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()
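
For orientation, a minimal onnxruntime sketch of what a helper like run_onnx could look like; this is illustrative, not the project's implementation, and it assumes float32 inputs:

import numpy as np
import onnxruntime as ort

def run_onnx_sketch(onnx_path, inputs):
    """Run an exported ONNX model; inputs is a list of torch tensors."""
    session = ort.InferenceSession(onnx_path)
    feeds = {meta.name: t.cpu().numpy().astype(np.float32)
             for meta, t in zip(session.get_inputs(), inputs)}
    return session.run(None, feeds)[0]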
Example #22
        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}            
            t = (time.time() - iter_start_time) / opt.print_freq
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)
            #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 

        ### display output images
        if save_fake:
            
            ricker = util.tensor2im(data['image'][0])
            generated_im = util.tensor2im(generated.data[0])
            psf = util.tensor2label(data['label'][0], opt.label_nc)

            # take the absolute difference in a signed dtype so the uint8 images don't wrap around
            diff_image = np.abs(ricker.astype(np.int16) - generated_im.astype(np.int16)).astype(np.uint8)
            print("Acc:", score(ricker, generated_im))

            visuals = OrderedDict([('PSF', psf),
                                   ('Generated', generated_im),
                                   ('Ricker', ricker),
                                   ('Difference', diff_image)])
            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        """if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')            
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')"""
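
score is not defined in this excerpt; one plausible stand-in compares the two uint8 images with a normalized L1 measure (purely illustrative, not the project's metric):

import numpy as np

def score(real, fake):
    """Illustrative similarity in [0, 1]; 1.0 means the images are identical."""
    real = real.astype(np.float32)
    fake = fake.astype(np.float32)
    return 1.0 - np.abs(real - fake).mean() / 255.0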
Example #23
def main():
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except (OSError, ValueError):  # missing or malformed iter file
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.niter = 1
        opt.niter_decay = 0
        opt.max_dataset_size = 10

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        from apex import amp
        model, [optimizer_G, optimizer_D] = amp.initialize(
            model, [model.optimizer_G, model.optimizer_D], opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']),
                                      Variable(data['inst']),
                                      Variable(data['image']),
                                      Variable(data['feat']),
                                      infer=save_fake)

            # sum per device losses
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        iter_end_time = time.time()
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()
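
The fp16 path above relies on NVIDIA apex, which recent PyTorch replaces with torch.cuda.amp. A rough sketch of the native-AMP equivalent of the scaled backward passes (not a drop-in replacement for the loop above):

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_step(optimizer, loss):
    """One mixed-precision optimizer step; the forward pass should run under autocast."""
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()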
Example #24
def main():
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt)
    model = model.cuda()
    visualizer = Visualizer(opt)

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                      Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()


            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()}  # CHANGE: removed [0] after v.data
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    torch.cuda.empty_cache()
Example #25
def train(opt):

    iter_path = os.path.join(opt.checkpoints_dir, 'iter.txt')

    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except (OSError, ValueError):  # missing or malformed iter file
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.niter = 1
        opt.niter_decay = 0
        opt.max_dataset_size = 10

    if opt.mode == 'meta-train':
        print("--picked k random images")
        opt.txtfile_img, opt.txtfile_label = create_k_txtfile_rand(
            opt.txtfile_img, opt.txtfile_label,
            opt.checkpoints_dir + "rand_k_img.txt",
            opt.checkpoints_dir + "rand_k_label.txt", opt.k)

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)

    visualizer = Visualizer(opt)

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']),
                                      Variable(data['inst']),
                                      Variable(data['image']),
                                      Variable(data['feat']),
                                      infer=save_fake)

            # sum per device losses
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict.get('D_fake', 0) +
                      loss_dict.get('D_real', 0)) * 0.5
            loss_G = loss_dict.get('G_GAN', 0) + loss_dict.get(
                'G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            total_loss_dict = {'loss_D': loss_D, 'loss_G': loss_G}

            ############### Backward Pass ####################
            # update generator weights
            model.module.optimizer_G.zero_grad()
            loss_G.backward()
            model.module.optimizer_G.step()

            # update discriminator weights

            # loss_D is a plain number (D_fake/D_real absent from loss_dict)
            # when the discriminator was not run; only backprop real tensors
            if torch.is_tensor(loss_D):
                model.module.optimizer_D.zero_grad()
                loss_D.backward()
                model.module.optimizer_D.step()

            #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t,
                                                total_loss_dict)
                visualizer.plot_current_errors(errors, total_steps,
                                               total_loss_dict)

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                if opt.mode == 'test_checkpoints':
                    model.module.save(str(epoch) + '_' + str(total_steps))
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))
        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()
    if opt.mode != "meta-train":
        model.module.save('latest')
    exit_cuda()
    return model, loss_dict, total_loss_dict
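A hedged usage sketch for this entry point, assuming the repo's TrainOptions parser (as used in the next example):

opt = TrainOptions().parse()
model, loss_dict, total_loss_dict = train(opt)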
Example #27
def main():
    opt = TrainOptions().parse()

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except Exception:
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.niter = 1
        opt.niter_decay = 0
        opt.max_dataset_size = 10

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0

            losses, real, label, generated, res, comp, up = model(
                Variable(data['label']),
                Variable(data['image']),
                Variable(data['ds']),
                infer=save_fake)

            # sum per device losses
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_Feat'] + loss_dict[
                'G_VGG'] + loss_dict['G_DIS'] + loss_dict['G_SSIM']

            ############### Backward Pass ####################
            # update generator weights
            model.module.optimizer_G.zero_grad()
            loss_G.backward()
            model.module.optimizer_G.step()

            # update discriminator weights
            model.module.optimizer_D.zero_grad()
            loss_D.backward()
            model.module.optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == 0:
                # .item() replaces the deprecated v.data[0], which errors on
                # 0-dim tensors in current PyTorch
                errors = {
                    k: v.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                #errors['psp_loss'] = psp_train_loss.avg
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                # visualize only the first sample of the batch; a new name
                # avoids shadowing the enumerate index `i` above
                idx = 0
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][idx], opt.label_nc)),
                    ('fine_image', util.tensor2im(res.data[idx])),
                    ('comp_image', util.tensor2im(comp.data[idx])),
                    ('up_image', util.tensor2im(up.data[idx])),
                    ('synthesized_image', util.tensor2im(generated.data[idx])),
                    ('real_image', util.tensor2im(data['image'][idx]))
                ])
                visualizer.display_current_results(visuals, epoch, epoch_iter)

            ### save latest model
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()
Example #28
    label_encodings, num_labels = model.encode_features(Variable(
        data['input']))
    # convert tensor to numpy
    label_encodings = label_encodings.data.cpu().numpy()
    for label in label_dict:
        fname_feature_dict['_'.join([fname, str(label)])] = \
                            label_encodings[:, label_dict[label]*opt.nz: (label_dict[label]+1)*opt.nz]

    fname_feature_dict['_'.join([fname,
                                 str(0)])] = label_encodings[:, -1 * opt.nz:]
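    # layout note: each label owns an opt.nz-wide slice of the encoding
    # vector, and the background code (label 0) occupies the last opt.nz
    # columns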

    if sanity_check:
        # Sanity check the label_encodings
        generated = model.generate_from_random(label_encodings)
        util.save_image(
            util.tensor2label(data['input'][0], opt.output_nc,
                              normalize=False),
            os.path.join(img_dir, 'input_label_%s.jpg' % (fname)))
        util.save_image(
            util.tensor2label(generated.data[0], opt.output_nc,
                              normalize=True),
            os.path.join(img_dir, 'synthesized_label_%s.jpg' % (fname)))

dir_path = ROOT_DIR + '/separate_vae/results/Lab/demo/'
if not os.path.exists(dir_path):
    print(f"Creating directory {dir_path}")
    os.makedirs(dir_path)
if demo:
    save_name = os.path.join(dir_path, '%s_shape_codes.p' % opt.phase)
else:
    save_name = os.path.join(img_dir, '%s_shape_codes.p' % opt.phase)
with open(save_name, 'wb') as writefile:
    # the body is truncated in the source; pickling the collected dict is the
    # natural completion (an assumption), matching the .p extension above and
    # assuming pickle is imported
    pickle.dump(fname_feature_dict, writefile)
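A companion sketch for loading the codes back, assuming the file was written with pickle.dump as completed above:

import pickle

with open(save_name, 'rb') as readfile:
    fname_feature_dict = pickle.load(readfile)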
Example #29
    if opt.export_onnx:
        print("Exporting to ONNX: ", opt.export_onnx)
        assert opt.export_onnx.endswith(
            ".onnx"), "Export model file should end with .onnx"
        torch.onnx.export(model, [data['label'], data['inst']],
                          opt.export_onnx,
                          verbose=True)
        exit(0)
    minibatch = 1
    if opt.engine:
        generated = run_trt_engine(opt.engine, minibatch,
                                   [data['label'], data['inst']])
    elif opt.onnx:
        generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                             [data['label'], data['inst']])
    else:
        generated = model.inference(data['A'], data['B'], data['B2'])

    visuals = OrderedDict([('input_label', util.tensor2label(data['A'][0], 0)),
                           ('real_image', util.tensor2im(data['A2'][0])),
                           ('synthesized_image',
                            util.tensor2im(generated.data[0])),
                           ('B', util.tensor2label(data['B'][0], 0)),
                           ('B2', util.tensor2im(data['B2'][0]))])
    img_path = data['path']
    img_path[0] = str(i)
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()
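The export branch above relies on torch.onnx.export. A self-contained sketch of the same call shape, using a toy module rather than the repo's generator:

import torch

net = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
dummy = torch.randn(1, 3, 64, 64)  # an example input is traced to build the graph
torch.onnx.export(net, dummy, 'toy.onnx', verbose=True)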
Example #30
        if total_steps % opt.print_freq == print_delta:
            # .item() replaces the deprecated v.data[0] indexing
            errors = {
                k: v.item() if not isinstance(v, int) else v
                for k, v in loss_dict.items()
            }
            errors['loss_G'] = loss_G
            errors['loss_D'] = loss_D
            errors['loss_SD'] = loss_SD
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)

        ### display output images
        if save_fake:
            visuals = OrderedDict([
                ('input_label', util.tensor2label(data['A'][0], 256)),
                ('real_image', util.tensor2im(data['A2'][0])),
                ('synthesized_image', util.tensor2im(generated.data[0])),
                ('B', util.tensor2label(data['B'][0], 256)),
                ('B2', util.tensor2im(data['B2'][0]))
            ])
            visualizer.display_current_results2(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.module.save('latest')
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

        if opt.use_iter_decay and total_steps > opt.niter_iter:
            # truncated in the source; by analogy with the epoch-based decay in
            # the other examples, the body presumably steps the LR schedule
            model.module.update_learning_rate()
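For reference, a sketch of what update_learning_rate() typically does in pix2pixHD-style trainers (linear decay applied to both optimizers); the actual method in this repo may differ:

def update_learning_rate(self):
    # anneal lr linearly to zero over opt.niter_decay epochs
    lrd = self.opt.lr / self.opt.niter_decay
    lr = self.old_lr - lrd
    for param_group in self.optimizer_D.param_groups:
        param_group['lr'] = lr
    for param_group in self.optimizer_G.param_groups:
        param_group['lr'] = lr
    print('update learning rate: %f -> %f' % (self.old_lr, lr))
    self.old_lr = lr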
Example #31
    if opt.data_type == 16:
        data['label'] = data['label'].half()
        data['inst']  = data['inst'].half()
    elif opt.data_type == 8:
        data['label'] = data['label'].byte()  # tensors have no .uint8(); .byte() is the uint8 cast
        data['inst']  = data['inst'].byte()
    if opt.export_onnx:
        print ("Exporting to ONNX: ", opt.export_onnx)
        assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
        torch.onnx.export(model, [data['label'], data['inst']],
                          opt.export_onnx, verbose=True)
        exit(0)
    minibatch = 1 
    if opt.engine:
        generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
    elif opt.onnx:
        generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
    else:        
        generated = model.inference(data['label'], data['inst'], data['image'])
        
    visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                           ('synthesized_image', util.tensor2im(generated.data[0]))])
    img_path = data['path']
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

    print("----{}s seconds----".format(time.time() - start_time))


webpage.save()
Example #32
                ans.append(path)
    with open(target_path, 'w') as f_w:
        f_w.writelines(ans)


opt = TrainOptions().parse()
opt.phase = 'val'
write_temp(opt, "temp")
opt.phase = "temp"
opt.serial_batches = True

data_loader = CreatePoseConDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)

visualizer = Visualizer(opt)

total_steps = 0  # (start_epoch-1) * dataset_size + epoch_iter

display_delta = total_steps % opt.display_freq
print_delta = total_steps % opt.print_freq
save_delta = total_steps % opt.save_latest_freq

for i, data in enumerate(dataset):
    if i % 100 == 0:
        print('%d / %d' % (i, dataset_size))
    visuals = OrderedDict([('input_label',
                            util.tensor2label(data['A'][0][3:6, :, :], 0)),
                           ('real_image', util.tensor2im(data['A2'][0]))])
    visualizer.display_current_results2(visuals, 0, i)