Ejemplo n.º 1
0
    def __init__(self, config):
        """Set up the CycleGAN-style trainer: paired generators/discriminators,
        their optimizers, loss criteria, GAN labels, and tracked loss names.

        Args:
            config (dict): expects keys 'input_nc', 'ndf', 'lr', 'beta1'.
        """
        super(ImGANTrainer, self).__init__()

        # Two generators (a->b and b->a) and one patch discriminator per side.
        self.netG_ab = Generator(config)
        self.netG_ba = Generator(config)
        self.netD_ab = PatchD(config['input_nc'], config['ndf'])
        self.netD_ba = PatchD(config['input_nc'], config['ndf'])

        # Single Adam optimizer over both generators, another over both
        # discriminators; same lr/betas for each.
        gen_params = itertools.chain(self.netG_ab.parameters(),
                                     self.netG_ba.parameters())
        dis_params = itertools.chain(self.netD_ab.parameters(),
                                     self.netD_ba.parameters())
        adam_betas = (config['beta1'], 0.999)
        self.optimizer_g = torch.optim.Adam(gen_params,
                                            lr=config['lr'],
                                            betas=adam_betas)
        self.optimizer_d = torch.optim.Adam(dis_params,
                                            lr=config['lr'],
                                            betas=adam_betas)

        # criterion (attribute names kept exactly as-is — other methods of the
        # class may reference them)
        self.criteritionGAN = nn.BCELoss()
        self.criteritioL1 = nn.L1Loss()
        self.criteritiommd = MMDLoss()

        # labels used as BCE targets for real/fake samples
        self.real_label = 1.
        self.fake_label = 0.

        # losses reported/tracked by name
        self.loss_names = [
            'loss_D', 'loss_G', 'loss_cycle_aba', 'loss_cycle_bab', 'loss_mmd'
        ]
    def __init__(self, config):
        """Set up the inpainting trainer: generator, local/global
        discriminators, and their Adam optimizers.

        Args:
            config (dict): expects keys 'cuda', 'gpu_ids', 'netG', 'netD',
                'lr', 'beta1', 'beta2'.
        """
        super(Trainer, self).__init__()
        self.config = config
        self.use_cuda = self.config['cuda']
        self.device_ids = self.config['gpu_ids']

        self.netG = Generator(self.config['netG'], self.use_cuda,
                              self.device_ids)
        self.localD = LocalDis(self.config['netD'], self.use_cuda,
                               self.device_ids)
        self.globalD = GlobalDis(self.config['netD'], self.use_cuda,
                                 self.device_ids)

        # Move modules to the first GPU *before* building the optimizers so
        # optimizer state is created against the on-device parameters.
        if self.use_cuda:
            self.netG.to(self.device_ids[0])
            self.localD.to(self.device_ids[0])
            self.globalD.to(self.device_ids[0])

        self.optimizer_g = torch.optim.Adam(self.netG.parameters(),
                                            lr=self.config['lr'],
                                            betas=(self.config['beta1'],
                                                   self.config['beta2']))
        # Both discriminators share one optimizer.
        d_params = list(self.localD.parameters()) + list(
            self.globalD.parameters())
        # Was `lr=config['lr']`; use self.config consistently with optimizer_g.
        self.optimizer_d = torch.optim.Adam(d_params,
                                            lr=self.config['lr'],
                                            betas=(self.config['beta1'],
                                                   self.config['beta2']))
Ejemplo n.º 3
0
def main(args):
    """Generate args.n fake videos with a pretrained Generator and save them.

    Args:
        args: namespace with model_path, batch_size, n, d_za, d_zm, gen_path.
    """
    device = torch.device("cuda:0")

    G = Generator().to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    with torch.no_grad():
        G.eval()

        batch_size = args.batch_size
        # Ceil division. The old `args.n // batch_size + 1` produced one
        # extra zero-sized batch whenever args.n was a multiple of batch_size.
        n_epoch = (args.n + batch_size - 1) // batch_size

        for epoch in tqdm(range(n_epoch)):
            # The last batch may be smaller than batch_size.
            bs = min(batch_size, args.n - epoch * batch_size)
            za = torch.randn(bs, args.d_za, 1, 1, 1).to(device)  # appearance code
            zm = torch.randn(bs, args.d_zm, 1, 1, 1).to(device)  # motion code

            vid_fake = G(za, zm)

            # (bs, C, T, H, W) -> (bs, T, C, H, W)
            vid_fake = vid_fake.transpose(2, 1)  # bs x 16 x 3 x 64 x 64
            # Min-max normalize to [0, 1] for saving.
            vid_fake = ((vid_fake - vid_fake.min()) /
                        (vid_fake.max() - vid_fake.min())).data

            # save into videos
            save_videos(args.gen_path, vid_fake, epoch, bs)

    return
Ejemplo n.º 4
0
    def __init__(self, opts, nc_in=5, nc_out=3, d_s_args=None, d_t_args=None):
        """Video-inpainting GAN: a generator plus optional spatial/temporal
        SN-PatchGAN discriminators.

        Args:
            opts (dict): model options; requires 'nf', 'norm', 'bias';
                optional 'conv_by', 'conv_type', 'use_refine',
                'use_skip_connection', 'spatial_discriminator',
                'temporal_discriminator'.
            nc_in (int): generator input channels.
            nc_out (int): generator output channels.
            d_s_args (dict | None): overrides for the spatial discriminator.
            d_t_args (dict | None): overrides for the temporal discriminator.
        """
        super().__init__()
        # Avoid mutable default arguments (was `d_s_args={}, d_t_args={}`);
        # merge caller overrides over the defaults.
        self.d_t_args = {"nf": 32, "use_sigmoid": True, "norm": 'SN'}
        self.d_t_args.update(d_t_args or {})

        self.d_s_args = {"nf": 32, "use_sigmoid": True, "norm": 'SN'}
        self.d_s_args.update(d_s_args or {})

        nf = opts['nf']
        norm = opts['norm']
        use_bias = opts['bias']

        # warning: if 2d convolution is used in generator, settings (e.g. stride,
        # kernal_size, padding) on the temporal axis will be discarded
        self.conv_by = opts.get('conv_by', '3d')
        self.conv_type = opts.get('conv_type', 'gated')

        self.use_refine = opts.get('use_refine', False)
        use_skip_connection = opts.get('use_skip_connection', False)

        self.opts = opts

        ######################
        # Convolution layers #
        ######################
        self.generator = Generator(nc_in,
                                   nc_out,
                                   nf,
                                   use_bias,
                                   norm,
                                   self.conv_by,
                                   self.conv_type,
                                   use_refine=self.use_refine,
                                   use_skip_connection=use_skip_connection)

        #################
        # Discriminator #
        #################
        # Each discriminator defaults to ON unless explicitly disabled in opts.
        if opts.get('spatial_discriminator', True):
            self.spatial_discriminator = SNTemporalPatchGANDiscriminator(
                nc_in=5, conv_type='2d', **self.d_s_args)
        if opts.get('temporal_discriminator', True):
            self.temporal_discriminator = SNTemporalPatchGANDiscriminator(
                nc_in=5, **self.d_t_args)
Ejemplo n.º 5
0
def main(args):
    """Render one generated video per class for a batch of test images and
    log/save the results."""
    # Output locations for tensorboard logs and rendered videos.
    log_path = os.path.join('demos', args.dataset + '/log')
    vid_path = os.path.join('demos', args.dataset + '/vids')
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")

    # Build the generator, wrap for multi-GPU, and restore trained weights.
    G = nn.DataParallel(
        Generator(args.dim_z, args.dim_a, args.nclasses, args.ch).to(device))
    G.load_state_dict(torch.load(args.model_path))

    transform = torchvision.transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    dataset = MUG_test(args.data_path, transform=transform)

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=False,
                                             pin_memory=True)

    with torch.no_grad():
        G.eval()

        # A single batch of test images drives every class.
        img = next(iter(dataloader))
        bs = img.size(0)
        nclasses = args.nclasses

        # One shared latent code for the whole batch of images.
        z = torch.randn(bs, args.dim_z).to(device)

        for cls_idx in range(nclasses):
            # One-hot label selecting the current class for every sample.
            y = torch.zeros(bs, nclasses).to(device)
            y[:, cls_idx] = 1.0

            vid_gen = G(img, z, y).transpose(2, 1)
            # Min-max normalize to [0, 1] before logging/saving.
            lo, hi = vid_gen.min(), vid_gen.max()
            vid_gen = ((vid_gen - lo) / (hi - lo)).data

            writer.add_video(tag='vid_cat_%d' % cls_idx, vid_tensor=vid_gen)
            writer.flush()

            # save videos
            print('==> saving videos')
            save_videos(vid_path, vid_gen, bs, cls_idx)
Ejemplo n.º 6
0
def main():
    """Generate one video per (appearance, motion) noise pair and save/log them."""
    args = cfg.parse_args()

    # write into tensorboard
    log_path = os.path.join(args.demo_path, args.demo_name + '/log')
    vid_path = os.path.join(args.demo_path, args.demo_name + '/vids')

    # Create each directory independently. The previous combined check
    # (`not exists(log) and not exists(vid)`) skipped creation entirely when
    # only one of the two existed, crashing later on the missing one.
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")

    G = Generator().to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    with torch.no_grad():
        G.eval()

        za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device)  # appearance
        zm = torch.randn(args.n_zm_test, args.d_zm, 1, 1, 1).to(device)  # motion

        n_za = za.size(0)
        n_zm = zm.size(0)
        # Cartesian pairing: every appearance code with every motion code,
        # flattened into a batch of n_za * n_zm samples.
        za = za.unsqueeze(1).repeat(1, n_zm, 1, 1, 1, 1).contiguous().view(
            n_za * n_zm, -1, 1, 1, 1)
        zm = zm.repeat(n_za, 1, 1, 1, 1)

        vid_fake = G(za, zm)

        vid_fake = vid_fake.transpose(2, 1)  # bs x 16 x 3 x 64 x 64
        # Min-max normalize to [0, 1].
        vid_fake = ((vid_fake - vid_fake.min()) /
                    (vid_fake.max() - vid_fake.min())).data

        writer.add_video(tag='generated_videos',
                         global_step=1,
                         vid_tensor=vid_fake)
        writer.flush()

        # save into videos
        print('==> saving videos...')
        save_videos(vid_path, vid_fake, n_za, n_zm)

    return
Ejemplo n.º 7
0
def loadGenerator(args):
    """Build a Generator, restore its most recent checkpoint, and (optionally)
    wrap it for multi-GPU inference.

    Args:
        args: namespace with g_config, checkpoint_path, iter.

    Returns:
        The loaded generator network.
    """
    config = get_config(args.g_config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        # After masking, the visible devices are renumbered from 0.
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Set checkpoint path (derive a default from the config when not given)
    checkpoint_path = args.checkpoint_path or os.path.join(
        'checkpoints', config['dataset_name'],
        config['mask_type'] + '_' + config['expname'])

    # Define the trainer
    netG = Generator(config['netG'], cuda, device_ids).cuda()
    # Resume weight
    last_model_name = get_model_list(checkpoint_path,
                                     "gen",
                                     iteration=args.iter)
    # Iteration number is encoded in the checkpoint filename.
    model_iteration = int(last_model_name[-11:-3])
    netG.load_state_dict(torch.load(last_model_name))

    print("Configuration: {}".format(config))
    print("Resume from {} at iteration {}".format(checkpoint_path,
                                                  model_iteration))

    if cuda:
        netG = nn.parallel.DataParallel(netG, device_ids=device_ids)

    return netG
Ejemplo n.º 8
0
def main():
    """Generate videos of increasing length (16..48 frames) from a fixed
    appearance code, logging and saving each length variant."""
    args = cfg.parse_args()

    # write into tensorboard
    log_path = os.path.join(args.demo_path, args.demo_name + '/log')
    vid_path = os.path.join(args.demo_path, args.demo_name + '/vids')

    # Create each directory independently. The previous combined check
    # (`not exists(log) and not exists(vid)`) skipped both makedirs when
    # only one directory already existed, crashing later.
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")

    G = Generator().to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    with torch.no_grad():
        G.eval()

        za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device)  # appearance

        # generating frames from [16, 20, 24, 28, 32, 36, 40, 44, 48]
        for i in range(9):
            # Motion code with i+1 temporal steps -> 16 + i*4 output frames.
            zm = torch.randn(args.n_zm_test, args.d_zm, (i + 1), 1,
                             1).to(device)  # 16+i*4
            vid_fake = G(za, zm)
            vid_fake = vid_fake.transpose(2, 1)
            # Min-max normalize to [0, 1].
            vid_fake = ((vid_fake - vid_fake.min()) /
                        (vid_fake.max() - vid_fake.min())).data
            writer.add_video(tag='generated_videos_%dframes' % (16 + i * 4),
                             global_step=1,
                             vid_tensor=vid_fake)
            writer.flush()

            print('saving videos')
            save_videos(vid_path, vid_fake, args.n_za_test, (16 + i * 4))

    return
Ejemplo n.º 9
0
def main():
    """Train the conditional video GAN (generator + video/image
    discriminators) on the MUG dataset, with periodic validation, testing,
    and checkpointing."""
    args = cfg.parse_args()
    # Seed both CPU and CUDA RNGs. The original seeded only CUDA, so
    # CPU-side randomness (dataloader shuffling, any CPU sampling) was
    # not reproducible.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    print(args)

    # create logging folder
    log_path = os.path.join(args.save_path, args.exp_name + '/log')
    model_path = os.path.join(args.save_path, args.exp_name + '/models')
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)
    writer = SummaryWriter(log_path)  # tensorboard

    # load model
    print('==> loading models')
    device = torch.device("cuda:0")

    G = Generator(args.dim_z, args.dim_a, args.nclasses, args.ch).to(device)
    VD = VideoDiscriminator(args.nclasses, args.ch).to(device)
    ID = ImageDiscriminator(args.ch).to(device)

    G = nn.DataParallel(G)
    VD = nn.DataParallel(VD)
    ID = nn.DataParallel(ID)

    # optimizer (one Adam per network; standard GAN betas)
    optimizer_G = torch.optim.Adam(G.parameters(), args.g_lr, (0.5, 0.999))
    optimizer_VD = torch.optim.Adam(VD.parameters(), args.d_lr, (0.5, 0.999))
    optimizer_ID = torch.optim.Adam(ID.parameters(), args.d_lr, (0.5, 0.999))

    # loss
    criterion_gan = nn.BCEWithLogitsLoss().to(device)
    criterion_l1 = nn.L1Loss().to(device)

    # prepare dataset
    print('==> preparing dataset')
    # Clip-level transforms for training videos; frame-level for test images.
    transform = torchvision.transforms.Compose([
        transforms_vid.ClipResize((args.img_size, args.img_size)),
        transforms_vid.ClipToTensor(),
        transforms_vid.ClipNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    transform_test = torchvision.transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    if args.dataset == 'mug':
        dataset_train = MUG('train', args.data_path, transform=transform)
        dataset_val = MUG('val', args.data_path, transform=transform)
        dataset_test = MUG_test(args.data_path, transform=transform_test)
    else:
        raise NotImplementedError

    dataloader_train = torch.utils.data.DataLoader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=True)

    dataloader_val = torch.utils.data.DataLoader(dataset=dataset_val,
                                                 batch_size=args.batch_size,
                                                 num_workers=args.num_workers,
                                                 shuffle=False,
                                                 pin_memory=True)

    dataloader_test = torch.utils.data.DataLoader(
        dataset=dataset_test,
        batch_size=args.batch_size_test,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=True)

    print('==> start training')
    for epoch in range(args.max_epoch):
        train(args, epoch, G, VD, ID, optimizer_G, optimizer_VD, optimizer_ID,
              criterion_gan, criterion_l1, dataloader_train, writer, device)

        # Periodic validation/testing and checkpointing (both run at epoch 0).
        if epoch % args.val_freq == 0:
            val(args, epoch, G, criterion_l1, dataloader_val, device, writer)
            test(args, epoch, G, dataloader_test, device, writer)

        if epoch % args.save_freq == 0:
            torch.save(G.state_dict(),
                       os.path.join(model_path, 'G_%d.pth' % (epoch)))
            torch.save(VD.state_dict(),
                       os.path.join(model_path, 'VD_%d.pth' % (epoch)))
            torch.save(ID.state_dict(),
                       os.path.join(model_path, 'ID_%d.pth' % (epoch)))

    return
def main():
    """Inpaint a single image (with a supplied or random mask) using a
    trained Generator and save the result (and optionally the offset flow)."""
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        # Visible devices are renumbered from 0 after masking.
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    print("Arguments: {}".format(args))

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    print("Configuration: {}".format(config))

    try:  # for unexpected error logging
        with torch.no_grad():   # enter no grad context
            if is_image_file(args.image):
                if args.mask and is_image_file(args.mask):
                    # Test a single masked image with a given mask
                    x = default_loader(args.image)
                    mask = default_loader(args.mask)
                    x = transforms.Resize(config['image_shape'][:-1])(x)
                    x = transforms.CenterCrop(config['image_shape'][:-1])(x)
                    mask = transforms.Resize(config['image_shape'][:-1])(mask)
                    mask = transforms.CenterCrop(
                        config['image_shape'][:-1])(mask)
                    x = transforms.ToTensor()(x)
                    # Keep a single mask channel.
                    mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
                    x = normalize(x)
                    # Zero out the masked region of the input.
                    x = x * (1. - mask)
                    x = x.unsqueeze(dim=0)
                    mask = mask.unsqueeze(dim=0)
                elif args.mask:
                    raise TypeError(
                        "{} is not an image file.".format(args.mask))
                else:
                    # Test a single ground-truth image with a random mask
                    ground_truth = default_loader(args.image)
                    ground_truth = transforms.Resize(
                        config['image_shape'][:-1])(ground_truth)
                    ground_truth = transforms.CenterCrop(
                        config['image_shape'][:-1])(ground_truth)
                    ground_truth = transforms.ToTensor()(ground_truth)
                    ground_truth = normalize(ground_truth)
                    ground_truth = ground_truth.unsqueeze(dim=0)
                    bboxes = random_bbox(
                        config, batch_size=ground_truth.size(0))
                    x, mask = mask_image(ground_truth, bboxes, config)

                # Set checkpoint path
                if not args.checkpoint_path:
                    checkpoint_path = os.path.join('checkpoints',
                                                   config['dataset_name'],
                                                   config['mask_type'] + '_' + config['expname'])
                else:
                    checkpoint_path = args.checkpoint_path

                # Define the trainer
                netG = Generator(config['netG'], cuda, device_ids)
                # Resume weight
                last_model_name = get_model_list(
                    checkpoint_path, "gen", iteration=args.iter)
                netG.load_state_dict(torch.load(last_model_name))
                # Iteration number is encoded in the checkpoint filename.
                model_iteration = int(last_model_name[-11:-3])
                print("Resume from {} at iteration {}".format(
                    checkpoint_path, model_iteration))

                if cuda:
                    netG = nn.parallel.DataParallel(
                        netG, device_ids=device_ids)
                    x = x.cuda()
                    mask = mask.cuda()

                # Inference: keep the known region from x, fill the hole from
                # the second-stage output x2.
                x1, x2, offset_flow = netG(x, mask)
                inpainted_result = x2 * mask + x * (1. - mask)

                vutils.save_image(inpainted_result, args.output,
                                  padding=0, normalize=True)
                print("Saved the inpainted result to {}".format(args.output))
                if args.flow:
                    vutils.save_image(offset_flow, args.flow,
                                      padding=0, normalize=True)
                    print("Saved offset flow to {}".format(args.flow))
            else:
                # Fixed: the original passed the bound `.format` method
                # instead of calling it with the filename.
                raise TypeError("{} is not an image file.".format(args.image))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise e
Ejemplo n.º 11
0
def main():
    """Inpaint a large image by chunking it, running the Generator on each
    masked chunk, and stitching the predictions back together."""
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        # Visible devices are renumbered from 0 after masking.
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    chunker = ImageChunker(config['image_shape'][0], config['image_shape'][1],
                           args.overlap)
    try:  # for unexpected error logging
        with torch.no_grad():  # enter no grad context
            if is_image_file(args.image):
                print("Loading image...")
                imgs, masks = [], []
                img_ori = default_loader(args.image)
                img_w, img_h = img_ori.size
                # Load mask txt file (bounding boxes alongside the image)
                fname = args.image.replace('.jpg', '.txt')
                bboxes, _ = load_bbox_txt(fname, img_w, img_h)
                mask_ori = create_mask(bboxes, img_w, img_h)
                chunked_images = chunker.dimension_preprocess(
                    np.array(deepcopy(img_ori)))
                chunked_masks = chunker.dimension_preprocess(
                    np.array(deepcopy(mask_ori)))
                for (x, msk) in zip(chunked_images, chunked_masks):
                    x = transforms.ToTensor()(x)
                    # Keep a single mask channel.
                    mask = transforms.ToTensor()(msk)[0].unsqueeze(dim=0)
                    # x = normalize(x)
                    # Zero out the masked region of the chunk.
                    x = x * (1. - mask)
                    x = x.unsqueeze(dim=0)
                    mask = mask.unsqueeze(dim=0)
                    imgs.append(x)
                    masks.append(mask)

                # Set checkpoint path
                if not args.checkpoint_path:
                    checkpoint_path = os.path.join(
                        'checkpoints', config['dataset_name'],
                        config['mask_type'] + '_' + config['expname'])
                else:
                    checkpoint_path = args.checkpoint_path

                # Define the trainer
                netG = Generator(config['netG'], cuda, device_ids)
                # Resume weight
                last_model_name = get_model_list(checkpoint_path,
                                                 "gen",
                                                 iteration=args.iter)
                netG.load_state_dict(torch.load(last_model_name))
                # Iteration number is encoded in the checkpoint filename.
                model_iteration = int(last_model_name[-11:-3])
                print("Resume from {} at iteration {}".format(
                    checkpoint_path, model_iteration))

                # Wrap for multi-GPU ONCE, before the chunk loop. The
                # original wrapped netG in DataParallel inside the loop,
                # nesting DataParallel around itself on every masked chunk.
                if cuda:
                    netG = nn.parallel.DataParallel(netG,
                                                    device_ids=device_ids)

                pred_imgs = []
                for (x, mask) in zip(imgs, masks):
                    if torch.max(mask) == 1:
                        # Chunk contains a hole: run the network.
                        if cuda:
                            x = x.cuda()
                            mask = mask.cuda()

                        # Inference
                        x1, x2, offset_flow = netG(x, mask)
                        inpainted_result = x2 * mask + x * (1. - mask)
                        inpainted_result = inpainted_result.squeeze(
                            dim=0).permute(1, 2, 0).cpu()
                        pred_imgs.append(inpainted_result.numpy())
                    else:
                        # No hole in this chunk: pass it through unchanged.
                        pred_imgs.append(
                            x.squeeze(dim=0).permute(1, 2, 0).numpy())

                pred_imgs = np.asarray(pred_imgs, dtype=np.float32)
                reconstructed_image = chunker.dimension_postprocess(
                    pred_imgs, np.array(img_ori))
                reconstructed_image = torch.tensor(
                    reconstructed_image).permute(2, 0, 1).unsqueeze(dim=0)
                vutils.save_image(reconstructed_image,
                                  args.output,
                                  padding=0,
                                  normalize=True)
                print("Saved the inpainted result to {}".format(args.output))
                if args.flow:
                    # NOTE(review): offset_flow is the flow of the LAST masked
                    # chunk and is undefined if no chunk contained a hole.
                    vutils.save_image(offset_flow,
                                      args.flow,
                                      padding=0,
                                      normalize=True)
                    print("Saved offset flow to {}".format(args.flow))
            else:
                # Fixed: the original passed the bound `.format` method
                # instead of calling it with the filename.
                raise TypeError("{} is not an image file.".format(args.image))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise e
Ejemplo n.º 12
0
def main(args):
    """Sample latent codes, run a trained InfoGAN generator, and save results.

    Writes generated image grids to <saveloc>/<iters>.jpg and the latent
    codes used for each row to metadata.txt.
    """
    # Set random seed for reproducibility
    seed = args.seed
    if seed is None:
        seed = random.randint(1, 10000)  # use if you want new results
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # directories
    saveloc = os.path.join(args.saveloc, args.expname)
    modelpath = os.path.join(args.modelpath, args.modelname)
    os.makedirs(saveloc, exist_ok=True)

    num_batches = 1   # no. of image batches to generate
    batch_size = 200  # no. of images to generate
    nc = 1    # number of channels in the training images (3 for color)
    nz = 62   # size of z latent vector (i.e. size of generator input)
    ndc = 10  # latent categorical (discrete) code size
    ncc = 3   # number of continuous latent codes
    ngf = 64
    fixed_exp = False  # set True to drive generation with hand-picked codes

    # Number of GPUs available. Use 0 for CPU mode.
    ngpu = 1
    if ngpu > 0:
        torch.cuda.set_device(0)

    # load model weights
    netG = Generator(ndc + ncc + nz, nc, ngf)
    print('********* Generator **********\n', netG)
    netG.load_state_dict(torch.load(modelpath))
    netG.eval()

    if ngpu > 0:
        # assign to GPU
        netG = netG.cuda()

    print("Starting Testing Loop...")

    if fixed_exp:
        # Hand-picked latent codes for controlled generation experiments.
        z_rand = torch.randn((batch_size, nz, 1, 1))
        z_disc = torch.LongTensor(np.random.randint(ndc, size=(batch_size, 1)))
        z_cont = torch.rand((batch_size, ncc, 1, 1)) * 2 - 1

    for iters in range(num_batches):
        # fake batch
        if fixed_exp:
            noise, idx = controlled_noise_sample(
                batch_size, ndc,
                z_random=z_rand,
                z_categorical=z_disc,
                z_continuous=z_cont)
        else:
            # NOTE(review): noise_sample is called with batch size 1 while the
            # metadata loop below iterates batch_size rows — verify intent.
            noise, idx = noise_sample(1, ndc, ncc, nz, 1)
        if ngpu > 0:
            noise = noise.cuda()

        # Generate under no_grad. The original also ran `fake = netG(noise)`
        # once OUTSIDE no_grad and discarded the result, wasting compute and
        # building an unused autograd graph; that call is removed.
        with torch.no_grad():
            fake = netG(noise).detach()
            vutils.save_image(
                fake,
                os.path.join(saveloc, str(iters) + '.jpg'),
                nrow=20,
                normalize=True,
                range=(0.0, 1.0))

        # Record the latent codes used for every generated image.
        with open(os.path.join(saveloc, 'metadata.txt'), 'a') as f:
            for lineno in range(batch_size):
                if batch_size == 1:
                    f.write('C1: {:1.0f}, '.format(idx.item()))
                else:
                    f.write('C1: {:1.0f}, '.format(idx[lineno].item()))
                # Continuous codes occupy noise[:, nz+ndc:].
                for i, item in enumerate(noise[lineno, nz + ndc:].squeeze()):
                    f.write('C' + str(2 + i) + ': {:1.4f}, '.format(item.item()))
                f.write('\n')

        print('Generated file {}'.format(iters))
Ejemplo n.º 13
0
    def __init__(self,
                 opts,
                 nc_in=5,
                 nc_out=3,
                 d_s_args=None,
                 d_t_args=None,
                 losses=None):
        """Video-inpainting GAN with selectable backbone: generator, optional
        spatial/temporal SN-PatchGAN discriminators, VGG features, and
        configurable loss modules.

        Args:
            opts (dict): model options; requires 'nf', 'norm', 'bias',
                'flow_tsm'; optional 'conv_by', 'conv_type', 'use_refine',
                'use_skip_connection', 'backbone',
                'spatial_discriminator', 'temporal_discriminator'.
            nc_in (int): generator input channels.
            nc_out (int): generator output channels.
            d_s_args (dict | None): overrides for the spatial discriminator.
            d_t_args (dict | None): overrides for the temporal discriminator.
            losses (dict | None): loss-nickname -> weight; a loss module is
                instantiated for every entry with weight > 0.
        """
        super().__init__()
        # Avoid mutable default arguments (was `d_s_args={}, d_t_args={}`);
        # merge caller overrides over the defaults.
        self.d_t_args = {"nf": 32, "use_sigmoid": True, "norm": 'SN'}
        self.d_t_args.update(d_t_args or {})

        self.d_s_args = {"nf": 32, "use_sigmoid": True, "norm": 'SN'}
        self.d_s_args.update(d_s_args or {})

        nf = opts['nf']
        norm = opts['norm']
        use_bias = opts['bias']

        # warning: if 2d convolution is used in generator, settings (e.g. stride,
        # kernal_size, padding) on the temporal axis will be discarded
        self.conv_by = opts.get('conv_by', '3d')
        self.conv_type = opts.get('conv_type', 'gated')
        self.flow_tsm = opts['flow_tsm']

        self.use_refine = opts.get('use_refine', False)
        use_skip_connection = opts.get('use_skip_connection', False)

        self.backbone = opts.get('backbone', 'unet')

        self.opts = opts

        ######################
        # Convolution layers #
        ######################
        self.generator = Generator(self.backbone,
                                   nc_in,
                                   nc_out,
                                   nf,
                                   use_bias,
                                   norm,
                                   self.conv_by,
                                   self.conv_type,
                                   use_refine=self.use_refine,
                                   use_skip_connection=use_skip_connection,
                                   use_flow_tsm=self.flow_tsm)

        #################
        # Discriminator #
        #################
        # Each discriminator defaults to ON unless explicitly disabled in opts.
        if opts.get('spatial_discriminator', True):
            self.spatial_discriminator = SNTemporalPatchGANDiscriminator(
                nc_in=5, conv_type='2d', **self.d_s_args)
            self.advloss = AdversarialLoss()

        if opts.get('temporal_discriminator', True):
            self.temporal_discriminator = SNTemporalPatchGANDiscriminator(
                nc_in=5, **self.d_t_args)
            self.advloss = AdversarialLoss()

        #######
        # Vgg #
        #######
        self.vgg = Vgg16(requires_grad=False)

        ########
        # Loss #
        ########
        # Guard against losses=None: the original defaulted to None and then
        # unconditionally called losses.items(), raising AttributeError when
        # the argument was omitted. None is treated as "no losses configured".
        self.losses = losses
        for key, value in (losses or {}).items():
            if value > 0:
                setattr(self, key, loss_nickname_to_module[key]())
Ejemplo n.º 14
0
def train_distributed(config, logger, writer, checkpoint_path):
    """DistributedDataParallel (DDP) training loop for a WGAN-GP inpainting model.

    Executed by every launched process (one per GPU). Initializes the NCCL
    process group from environment variables (init_method='env://'), wraps
    the generator and the local/global discriminators in DDP, and alternates
    updates: the discriminators train on most iterations, the generator only
    when ``iteration % config['n_critic'] == 0``.

    Args:
        config: dict-like training configuration (netG/netD specs, lr,
            beta1/beta2, batch_size, niter, n_critic, loss weights,
            data paths, gpu_ids, viz/snapshot intervals).
        logger: model summaries are logged through it on rank 0 only.
        writer: SummaryWriter-like scalar sink, written on rank 0 only.
        checkpoint_path: directory receiving visualizations and snapshots.
    """
    
    dist.init_process_group(                                   
        backend='nccl',
#         backend='gloo',
        init_method='env://'
    )  
    
    
    # Find out what GPU on this compute node.
    # NOTE(review): torch.distributed.get_rank() is the *global* rank; using
    # it as the CUDA device index below is only correct for single-node runs.
    local_rank = torch.distributed.get_rank()
    
    
    # this is the total # of GPUs across all nodes
    # if using 2 nodes with 4 GPUs each, world size is 8
    #
    world_size = torch.distributed.get_world_size()
    print("### global rank of curr node: {} of {}".format(local_rank, world_size))
    
    
    # For multiprocessing distributed, DistributedDataParallel constructor
    # should always set the single device scope, otherwise,
    # DistributedDataParallel will use all available devices.
    #
    print("local_rank: ", local_rank)
#     dist.barrier()
    torch.cuda.set_device(local_rank)
    
    
    # Define the trainer
    print("Creating models on device: ", local_rank)
    
    
    # NOTE(review): input_dim/cnum/use_cuda/gated are read from the config
    # but never used below — presumably leftovers from a non-DDP variant.
    input_dim = config['netG']['input_dim']
    cnum = config['netG']['ngf']
    use_cuda = True
    gated = config['netG']['gated']
    
    
    # Models
    #
    netG = Generator(config['netG'], use_cuda=True, device=local_rank).cuda()
    netG = torch.nn.parallel.DistributedDataParallel(
        netG,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    
    localD = LocalDis(config['netD'], use_cuda=True, device_id=local_rank).cuda()
    localD = torch.nn.parallel.DistributedDataParallel(
        localD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )
    
    
    globalD = GlobalDis(config['netD'], use_cuda=True, device_id=local_rank).cuda()
    globalD = torch.nn.parallel.DistributedDataParallel(
        globalD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )
    
    
    # Log model summaries once (rank 0 only).
    if local_rank == 0:
        logger.info("\n{}".format(netG))
        logger.info("\n{}".format(localD))
        logger.info("\n{}".format(globalD))
        
    
    # Optimizers
    #
    optimizer_g = torch.optim.Adam(
        netG.parameters(),
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )

    
    # Both discriminators share a single optimizer.
    d_params = list(localD.parameters()) + list(globalD.parameters())
    optimizer_d = torch.optim.Adam(
        d_params,  
        lr=config['lr'],                                    
        betas=(config['beta1'], config['beta2'])                              
    )
    
    
    # Data
    #
    sampler = None
    train_dataset = Dataset(
        data_path=config['train_data_path'],
        with_subfolder=config['data_with_subfolder'],
        image_shape=config['image_shape'],
        random_crop=config['random_crop']
    )
        
    
    # NOTE(review): num_replicas is normally the world size; using
    # len(config['gpu_ids']) only matches it for single-node jobs. The rank
    # argument is left to its default (read from the process group).
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
#             num_replicas=torch.cuda.device_count(),
        num_replicas=len(config['gpu_ids']),
#         rank = local_rank
    )
    
    
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=config['batch_size'],
        shuffle=(sampler is None),
        num_workers=config['num_workers'],
        pin_memory=True,
        sampler=sampler,
        drop_last=True
    )
    
    
    # Get the resume iteration to restart training
    #
#     start_iteration = trainer.resume(config['resume']) if config['resume'] else 1
    start_iteration = 1
    print("\n\nStarting epoch: ", start_iteration)

    iterable_train_loader = iter(train_loader)

    if local_rank == 0: 
        time_count = time.time()

    # NOTE(review): each progress-bar step consumes one *batch* ("iteration"),
    # even though the sampler treats it as an epoch via set_epoch() below.
    epochs = config['niter'] + 1
    pbar = tqdm(range(start_iteration, epochs), dynamic_ncols=True, smoothing=0.01)
    for iteration in pbar:
        # Re-seed the sampler's shuffling so all ranks stay in sync.
        sampler.set_epoch(iteration)
        
        # Cycle the loader indefinitely.
        try:
            ground_truth = next(iterable_train_loader)
        except StopIteration:
            iterable_train_loader = iter(train_loader)
            ground_truth = next(iterable_train_loader)

        # Prepare the inputs: sample random rectangular holes and mask them out.
        bboxes = random_bbox(config, batch_size=ground_truth.size(0))
        x, mask = mask_image(ground_truth, bboxes, config)

        
        # Move to proper device.
        #
        bboxes = bboxes.cuda(local_rank)
        x = x.cuda(local_rank)
        mask = mask.cuda(local_rank)
        ground_truth = ground_truth.cuda(local_rank)
        

        ###### Forward pass ######
        # Generator losses are only computed every n_critic-th iteration.
        compute_g_loss = iteration % config['n_critic'] == 0
#         losses, inpainted_result, offset_flow = forward(config, x, bboxes, mask, ground_truth,
#                                                        localD=localD, globalD=globalD,
#                                                        coarse_gen=coarse_generator, fine_gen=fine_generator,
#                                                        local_rank=local_rank, compute_loss_g=compute_g_loss)
        losses, inpainted_result, offset_flow = forward(config, x, bboxes, mask, ground_truth,
                                                       netG=netG, localD=localD, globalD=globalD,
                                                       local_rank=local_rank, compute_loss_g=compute_g_loss)

        
        # Scalars from different devices are gathered into vectors,
        # so reduce any non-scalar loss entry to its mean.
        for k in losses.keys():
            if not losses[k].dim() == 0:
                losses[k] = torch.mean(losses[k])
                
                
        ###### Backward pass ######
        # Update D (WGAN critic loss + gradient penalty)
        if not compute_g_loss:
            optimizer_d.zero_grad()
            losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
            losses['d'].backward()
            optimizer_d.step() 

        # Update G (weighted autoencoder + L1 + adversarial terms)
        if compute_g_loss:
            optimizer_g.zero_grad()
            losses['g'] = losses['ae'] * config['ae_loss_alpha']
            losses['g'] += losses['l1'] * config['l1_loss_alpha']
            losses['g'] += losses['wgan_g'] * config['gan_loss_alpha']
            losses['g'].backward()
            optimizer_g.step()


        # Set tqdm description.
        # 'g' and 'd' alternate between iterations, so .get() substitutes 0.
        if local_rank == 0:
            log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
            message = ' '
            for k in log_losses:
                v = losses.get(k, 0.)
                writer.add_scalar(k, v, iteration)
                message += '%s: %.4f ' % (k, v)

            pbar.set_description(
                (
                    f" {message}"
                )
            )
            
                
        if local_rank == 0:      
            # Periodically dump (input, inpainted, flow) triplets as a grid.
            if iteration % (config['viz_iter']) == 0:
                    viz_max_out = config['viz_max_out']
                    if x.size(0) > viz_max_out:
                        viz_images = torch.stack([x[:viz_max_out], inpainted_result[:viz_max_out],
                                                  offset_flow[:viz_max_out]], dim=1)
                    else:
                        viz_images = torch.stack([x, inpainted_result, offset_flow], dim=1)
                    viz_images = viz_images.view(-1, *list(x.size())[1:])
                    vutils.save_image(viz_images,
                                      '%s/niter_%08d.png' % (checkpoint_path, iteration),
                                      nrow=3 * 4,
                                      normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                save_model(
                    netG, globalD, localD, optimizer_g, optimizer_d, checkpoint_path, iteration
                )
# Ejemplo n.º 15 (score: 0)
def train_distributed_v2(config, logger, writer, checkpoint_path):
    """DDP training loop, v2: single SN-PatchGAN discriminator + hinge loss.

    Differences from train_distributed(): free-form masks instead of bboxes,
    one patch discriminator instead of the local/global pair, hinge GAN loss
    instead of WGAN-GP, and both networks are updated on every iteration.

    Args:
        config: dict-like training configuration (netG/netD specs, lr,
            beta1/beta2, batch_size, niter, l1_loss_alpha,
            random_ff_settings, data paths, gpu_ids, viz/snapshot intervals).
        logger: model summaries are logged through it on rank 0 only.
        writer: unused here (the add_scalar call is commented out).
        checkpoint_path: directory receiving visualizations and snapshots.
    """
    
    dist.init_process_group(                                   
        backend='nccl',
#         backend='gloo',
        init_method='env://'
    )  
    
    
    # Find out what GPU on this compute node.
    # NOTE(review): torch.distributed.get_rank() is the *global* rank; using
    # it as the CUDA device index below is only correct for single-node runs.
    local_rank = torch.distributed.get_rank()
    
    
    # this is the total # of GPUs across all nodes
    # if using 2 nodes with 4 GPUs each, world size is 8
    #
    world_size = torch.distributed.get_world_size()
    print("### global rank of curr node: {} of {}".format(local_rank, world_size))
    
    
    # For multiprocessing distributed, DistributedDataParallel constructor
    # should always set the single device scope, otherwise,
    # DistributedDataParallel will use all available devices.
    #
    print("local_rank: ", local_rank)
#     dist.barrier()
    torch.cuda.set_device(local_rank)
    print("Creating models on device: ", local_rank)
    
    
    # Various definitions for models, etc.
    # NOTE(review): input_dim/cnum/use_cuda/gated are read but never used
    # below — presumably leftovers from an earlier variant.
    input_dim = config['netG']['input_dim']
    cnum = config['netG']['ngf']
    use_cuda = True
    gated = config['netG']['gated']
    batch_size = config['batch_size']
    
    
    # L1 loss used on outputs from course and fine networks in generator.
    #
    loss_l1 = nn.L1Loss(reduction='mean').cuda()
    
    # Models
    #
    netG = Generator(config['netG'], use_cuda=True, device=local_rank).cuda()
    netG = torch.nn.parallel.DistributedDataParallel(
        netG,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    
    patchD = PatchDis(config['netD'], use_cuda=True, device=local_rank).cuda()
    patchD = torch.nn.parallel.DistributedDataParallel(
        patchD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )
    
    
    if local_rank == 0:
        logger.info("\n{}".format(netG))
        logger.info("\n{}".format(patchD))
        
    
    # Optimizers
    #
    optimizer_g = torch.optim.Adam(
        netG.parameters(),
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )
    
    optimizer_d = torch.optim.Adam(
        patchD.parameters(),
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )
    
    # NOTE(review): duplicate of the model-summary logging above — redundant.
    if local_rank == 0:
        logger.info("\n{}".format(netG))
        logger.info("\n{}".format(patchD))
    
    
    # Data
    #
    sampler = None
    train_dataset = Dataset(
        data_path=config['train_data_path'],
        with_subfolder=config['data_with_subfolder'],
        image_shape=config['image_shape'],
        random_crop=config['random_crop']
    )
        
    
    # NOTE(review): num_replicas is normally the world size; using
    # len(config['gpu_ids']) only matches it for single-node jobs.
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
#             num_replicas=torch.cuda.device_count(),
        num_replicas=len(config['gpu_ids']),
#         rank = local_rank
    )
    
    
    # drop_last=True guarantees full batches, so the fixed-size random
    # free-form mask batch below always matches the image batch.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),
        num_workers=config['num_workers'],
        pin_memory=True,
        sampler=sampler,
        drop_last=True
    )
    
    
    # Scalar loss values (floats) kept only for progress reporting.
    losses = {
        'coarse': 0.0,
        'fine': 0.0,
        'ae': 0.0,
        'g_loss': 0.0,
        'd_loss': 0.0
    }
    
    # Get the resume iteration to restart training
    #
    ### TODO:
    ### - allow resuming from checkpoint.
    ###
#     start_iteration = trainer.resume(config['resume']) if config['resume'] else 1
    start_iteration = 1
    print("\n\nStarting epoch: ", start_iteration)

    iterable_train_loader = iter(train_loader)

    if local_rank == 0: 
        time_count = time.time()

    epochs = config['niter'] + 1
    pbar = tqdm(range(start_iteration, epochs), dynamic_ncols=True, smoothing=0.01)
    for iteration in pbar:
        # Re-seed the sampler's shuffling so all ranks stay in sync.
        sampler.set_epoch(iteration)
        
        # Cycle the loader indefinitely.
        try:
            ground_truth = next(iterable_train_loader)
        except StopIteration:
            iterable_train_loader = iter(train_loader)
            ground_truth = next(iterable_train_loader)
    
        ground_truth = ground_truth.cuda(local_rank)
        mask_ff = random_ff_mask(config['random_ff_settings'], batch_size=batch_size).cuda(local_rank)
        
#         netG.zero_grad()
        imgs_incomplete = ground_truth * (1. - mask_ff) # just background
        # x1 = coarse output, x2 = refined output.
        x1, x2, offset_flow = netG(imgs_incomplete, mask_ff)
        # Paste the generated content into the hole, keep real background.
        imgs_complete = (x2 * mask_ff) + imgs_incomplete
        
        
        # Losses 
        #
        coarse_loss = config['l1_loss_alpha'] * loss_l1(ground_truth, x1)
        fine_loss = config['l1_loss_alpha'] * loss_l1(ground_truth, x2)
        ae_loss = coarse_loss + fine_loss
        losses['coarse'] = coarse_loss.item()
        losses['fine'] = fine_loss.item()
        losses['ae'] = ae_loss.item()
        

        
        # Discriminate
        #
        batch_pos_neg = torch.cat([ground_truth, imgs_complete], dim=0) # [N3HW]
        
        # Add in mask and repeat for ground truth and generated completion.
        # Will be split later to produce "real" and "fake" patch features in discriminator
        # for use with hinge loss.
        #
        batch_pos_neg= torch.cat([batch_pos_neg, mask_ff.repeat(2, 1, 1, 1)], dim=1) # [N4HW]
#         patchD.zero_grad()
        pos_neg = patchD(batch_pos_neg)
        
        
        # Losses
        #
        pos, neg = torch.chunk(pos_neg, 2, dim=0)
        g_loss, d_loss = gan_hinge_loss(pos, neg)
        g_loss += ae_loss
        losses['g_loss'] = g_loss.item()
        losses['d_loss'] = d_loss.item()
        
        
        # NOTE(review): compute_g_loss is no longer used — the n_critic
        # gating is commented out and both networks update every iteration.
        compute_g_loss = iteration % config['n_critic'] == 0
#         # Optimize
#         #
#         if not compute_g_loss:
        # Discriminator step. retain_graph=True presumably keeps the
        # generator-side graph alive so the second backward (g_loss path
        # through x2/ae_loss) still works — TODO confirm.
        optimizer_d.zero_grad()
        d_loss.backward(retain_graph=True)
        optimizer_d.step()
        
        
        # Re-run the (just-updated) discriminator to get a fresh generator
        # loss, then take the generator step.
        pos_neg = patchD(batch_pos_neg)
        pos, neg = torch.chunk(pos_neg, 2, dim=0)
        g_loss, d_loss = gan_hinge_loss(pos, neg)
        g_loss += ae_loss

#         if compute_g_loss:
        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()
        

#         print("ae_loss: ", ae_loss, " g_loss: ", g_loss, " d_loss: ", d_loss)
        # Set tqdm description
        #
        if local_rank == 0:
            message = ' '
            for k in losses:
#                 v = losses.get(k, 0.)
                v = losses[k]
#                 writer.add_scalar(k, v, iteration)
                message += '%s: %.4f ' % (k, v)

            pbar.set_description(
                (
                    f" {message}"
                )
            )
        
        # Save output from current iteration.
        #
        if local_rank == 0:      
            # Grid columns: ground truth | masked input | completion | flow.
            if iteration % (config['viz_iter']) == 0:
                    viz_max_out = config['viz_max_out']
                    if ground_truth.size(0) > viz_max_out:
                        viz_images = torch.stack(
                            [ground_truth[:viz_max_out],
                             imgs_incomplete[:viz_max_out],
                             imgs_complete[:viz_max_out],
                             offset_flow[:viz_max_out]],
                             dim=1
                        )
                    else:
                        viz_images = torch.stack(
                            [ground_truth,
                             imgs_incomplete,
                             imgs_complete,
                             offset_flow],
                             dim=1
                        )
                    viz_images = viz_images.view(-1, *list(ground_truth.size())[1:])
                    vutils.save_image(viz_images,
                                      '%s/niter_%08d.png' % (checkpoint_path, iteration),
                                      nrow=2 * 4,
                                      normalize=True)
        
            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                save_model_v2(netG, patchD, optimizer_g, optimizer_d, checkpoint_path, iteration)
# Ejemplo n.º 16 (score: 0)
    cudnn.benchmark = True
# NOTE(review): this chunk is the tail of a larger script; the argument
# parsing, config loading, and the `if cuda:` block that owns the indented
# line above lie outside this excerpt.
# print("Arguments: {}".format(args))
print("Use cuda: {}, use gpu_ids: {}".format(cuda, device_ids))

# Set random seed
if args.seed is None:
    args.seed = random.randint(1, 10000)
print("Random seed: {}".format(args.seed))
random.seed(args.seed)
torch.manual_seed(args.seed)
if cuda:
    torch.cuda.manual_seed_all(args.seed)
# print("Configuration: {}".format(config))

# Define the trainer
netG = Generator(config['netG'], cuda, device_ids)
# Resume weight
# if cuda:
#     netG.cuda()
# NOTE(review): the get_model_list() lookup below is dead code — its result
# is immediately overwritten by args.which_model on the next statement.
last_model_name = get_model_list(args.checkpoint_path,
                                 "gen",
                                 iteration=args.iter)
last_model_name = args.which_model

# last_model_name = args.which_model
# print("loading model from here --------------> {}".format(last_model_name))


# if not cuda:
#     netG.load_state_dict(torch.load(last_model_name, map_location='cpu'))
# else:
# Ejemplo n.º 17 (score: 0)
def main():
    """Inpaint annotated logo regions in the images/videos at ``args.image``.

    For each frame: read per-file bbox annotations (a ``.txt`` next to the
    output), build a binary mask, split the frame into fixed-size chunks,
    inpaint masked chunks with the generator, reassemble, and save the
    result as an image file or a video stream.

    Fixes vs. the original:
      * The generator was rebuilt, reloaded from disk, and re-wrapped in
        ``nn.parallel.DataParallel`` inside the per-frame/per-chunk loops —
        nesting DataParallel wrappers on every masked chunk. It is now built,
        loaded, and wrapped exactly once before the loop.
      * ``bboxes``/``bframes`` were unbound (NameError at the assert) when
        the annotation txt file did not exist; they now default to empty.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Set random seed (random draw if not supplied)
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    t0 = time.time()
    dataset = datasets.LoadImages(args.image)
    chunker = ImageChunker(config['image_shape'][0],
                           config['image_shape'][1],
                           args.overlap)
    try:  # for unexpected error logging
        with torch.no_grad():   # inference only
            # Set checkpoint path
            if not args.checkpoint_path:
                checkpoint_path = os.path.join('checkpoints', config['dataset_name'],
                                               config['mask_type'] + '_' + config['expname'])
            else:
                checkpoint_path = args.checkpoint_path
            last_model_name = get_model_list(checkpoint_path, "gen", iteration=args.iter)

            # Build, load, and (optionally) parallelize the generator ONCE.
            netG = Generator(config['netG'], cuda, device_ids)
            netG.load_state_dict(torch.load(last_model_name))
            model_iteration = int(last_model_name[-11:-3])  # iteration encoded in filename
            if cuda:
                netG = nn.parallel.DataParallel(netG, device_ids=device_ids)

            prev_fname = ''
            vid_writer = None
            bboxes, bframes = [], []  # stay defined even if no txt file exists
            for fpath, img_ori, vid_cap in dataset:
                imgs, masks = [], []
                if prev_fname == fpath:
                    frame += 1  # same file: advance the frame counter
                else:
                    frame = 0  # new file: restart the frame counter
                    _, img_h, img_w = img_ori.shape
                    txtfile = pathlib.Path(fpath).with_suffix('.txt')  # mask annotation file
                    txtfile = os.path.join(args.output, str(txtfile).split('/')[-1])
                    bboxes, bframes = [], []
                    if os.path.exists(txtfile):
                        bboxes, bframes = load_bbox_txt(txtfile, img_w, img_h)
                    assert len(bboxes) == len(bframes)

                # Boxes belonging to the current frame number only.
                idx = [ii for ii, val in enumerate(bframes) if val == frame]
                bndbxs = [bboxes[ii] for ii in idx]
                img_ori = np.moveaxis(img_ori, 0, -1)  # CHW -> HWC
                if len(bndbxs) > 0:  # if any logo detected
                    mask_ori = create_mask(bndbxs, img_w, img_h)
                    chunked_images = chunker.dimension_preprocess(np.array(deepcopy(img_ori)))
                    chunked_masks = chunker.dimension_preprocess(np.array(deepcopy(mask_ori)))
                    for (x, msk) in zip(chunked_images, chunked_masks):
                        x = transforms.ToTensor()(x)
                        mask = transforms.ToTensor()(msk)[0].unsqueeze(dim=0)
                        x = x * (1. - mask)  # zero out the masked pixels
                        x = x.unsqueeze(dim=0)
                        mask = mask.unsqueeze(dim=0)
                        imgs.append(x)
                        masks.append(mask)

                    pred_imgs = []
                    for (x, mask) in zip(imgs, masks):
                        if torch.max(mask) == 1:  # chunk actually overlaps a mask
                            if cuda:
                                x = x.cuda()
                                mask = mask.cuda()

                            # Inference: x1 = coarse result, x2 = refined result.
                            x1, x2, offset_flow = netG(x, mask)
                            inpainted_result = x2 * mask + x * (1. - mask)
                            inpainted_result = inpainted_result.squeeze(dim=0).permute(1, 2, 0).cpu()
                            pred_imgs.append(inpainted_result.numpy())
                        else:
                            pred_imgs.append(x.squeeze(dim=0).permute(1, 2, 0).numpy())

                    pred_imgs = np.asarray(pred_imgs, dtype=np.float32)
                    reconstructed_image = chunker.dimension_postprocess(pred_imgs, np.array(img_ori))
                    reconstructed_image = np.uint8(reconstructed_image[:, :, ::-1]*255) # BGR to RGB, and rescaling
                else:  # no logo detected — pass the frame through unchanged
                    reconstructed_image = img_ori[:, :, ::-1]

                # Save results (image with detections)
                outname = fpath.split('/')[-1]
                outname = outname.split('.')[0] + '-inp.' + outname.split('.')[-1]
                outpath = os.path.join(args.output, outname)
                if dataset.mode == 'images':
                    cv2.imwrite(outpath, reconstructed_image)
                    print("Saved the inpainted image to {}".format(outpath))
                else:
                    if fpath != prev_fname:  # new video
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer
                            print("Saved the inpainted video to {}".format(outpath))

                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(outpath, cv2.VideoWriter_fourcc(*args.fourcc), fps, (w, h))
                    vid_writer.write(reconstructed_image)
                    prev_fname = fpath
    # exit no grad context
    except Exception as err:  # for unexpected error logging
        print("Error: {}".format(err))
    print('Inpainting: (%.3fs)' % (time.time() - t0))
# Ejemplo n.º 18 (score: 0)
def generate(img, img_mask_path, model_path):
    """Inpaint a single image with the trained generator and return the result.

    If ``img_mask_path`` names an image file, it is used as the hole mask for
    ``img``. With no mask path, ``img`` is treated as ground truth and a
    random bbox mask is sampled. ``model_path`` overrides the default
    checkpoint directory when given.
    """
    with torch.no_grad():  # inference only — no autograd bookkeeping
        shape_hw = config['image_shape'][:-1]
        resize = transforms.Resize(shape_hw)
        crop = transforms.CenterCrop(shape_hw)
        to_tensor = transforms.ToTensor()

        if img_mask_path and is_image_file(img_mask_path):
            # Explicit mask supplied: resize/crop both, then mask the input.
            x = normalize(to_tensor(crop(resize(Image.fromarray(img)))))
            mask = to_tensor(crop(resize(default_loader(img_mask_path))))[0].unsqueeze(dim=0)
            x = (x * (1. - mask)).unsqueeze(dim=0)
            mask = mask.unsqueeze(dim=0)
        elif img_mask_path:
            raise TypeError("{} is not an image file.".format(img_mask_path))
        else:
            # No mask given: sample a random bbox mask over the ground truth.
            ground_truth = normalize(to_tensor(crop(resize(img)))).unsqueeze(dim=0)
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)

        # Resolve the checkpoint directory (an explicit model_path wins).
        checkpoint_path = model_path or os.path.join(
            'checkpoints',
            config['dataset_name'],
            config['mask_type'] + '_' + config['expname'])

        # Build the generator and restore the latest "gen" checkpoint.
        netG = Generator(config['netG'], cuda, device_ids)
        last_model_name = get_model_list(checkpoint_path, "gen", iteration=0)
        map_loc = None if cuda else 'cpu'
        netG.load_state_dict(torch.load(last_model_name, map_location=map_loc))

        model_iteration = int(last_model_name[-11:-3])  # iteration encoded in filename
        print("Resume from {} at iteration {}".format(checkpoint_path, model_iteration))

        if cuda:
            netG = nn.parallel.DataParallel(netG, device_ids=device_ids)
            x = x.cuda()
            mask = mask.cuda()

        # Inference: x1 = coarse output, x2 = refined output.
        x1, x2, offset_flow = netG(x, mask)
        # Keep real pixels outside the mask, generated pixels inside it.
        inpainted_result = x2 * mask + x * (1. - mask)
        return from_torch_img_to_numpy(inpainted_result, 'output.png', padding=0, normalize=True)
# Ejemplo n.º 19 (score: 0)
def main():
    """Train the video GAN: generator vs. video and image discriminators.

    Builds G/VD/ID wrapped in DataParallel, trains for ``args.max_epoch``
    epochs, periodically renders fixed latent samples for validation and
    checkpoints all three networks.
    """

    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    print(args)

    # Create logging/checkpoint folders.
    # Fix: the original guarded creation with
    # `not exists(log_path) and not exists(model_path)`, so when exactly one
    # of the two already existed NEITHER was created and SummaryWriter /
    # torch.save failed later. exist_ok=True makes creation idempotent.
    log_path = os.path.join(args.save_path, args.exp_name + '/log')
    model_path = os.path.join(args.save_path, args.exp_name + '/models')
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)
    writer = SummaryWriter(log_path)  # tensorboard

    # load model
    device = torch.device("cuda:0")

    G = Generator(args.d_za, args.d_zm, args.ch_g, args.g_mode,
                  args.use_attention).to(device)
    VD = VideoDiscriminator(args.ch_d).to(device)
    ID = ImageDiscriminator(args.ch_d).to(device)

    G = nn.DataParallel(G)
    VD = nn.DataParallel(VD)
    ID = nn.DataParallel(ID)

    # One Adam optimizer per network (common GAN betas 0.5/0.999).
    optimizer_G = torch.optim.Adam(G.parameters(), args.g_lr, (0.5, 0.999))
    optimizer_VD = torch.optim.Adam(VD.parameters(), args.d_lr, (0.5, 0.999))
    optimizer_ID = torch.optim.Adam(ID.parameters(), args.d_lr, (0.5, 0.999))

    # loss
    criterion = nn.BCEWithLogitsLoss().to(device)

    # prepare dataset
    print('==> preparing dataset')
    transform = torchvision.transforms.Compose([
        transforms_vid.ClipResize((args.img_size, args.img_size)),
        transforms_vid.ClipToTensor(),
        transforms_vid.ClipNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    dataset = UVA(args.data_path, transform=transform)

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=True,
                                             pin_memory=True,
                                             drop_last=True)

    # Fixed latents for validation renders (za: appearance, zm: motion —
    # presumably, judging by the d_za/d_zm naming; confirm against Generator).
    fixed_za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device)
    fixed_zm = torch.randn(args.n_zm_test, args.d_zm, 1, 1, 1).to(device)

    print('==> start training')
    for epoch in range(args.max_epoch):
        train(args, epoch, G, VD, ID, optimizer_G, optimizer_VD, optimizer_ID,
              criterion, dataloader, writer, device)

        if epoch % args.val_freq == 0:
            vis(epoch, G, fixed_za, fixed_zm, device, writer)

        if epoch % args.save_freq == 0:
            torch.save(G.state_dict(),
                       os.path.join(model_path, 'G_%d.pth' % (epoch)))
            torch.save(VD.state_dict(),
                       os.path.join(model_path, 'VD_%d.pth' % (epoch)))
            torch.save(ID.state_dict(),
                       os.path.join(model_path, 'ID_%d.pth' % (epoch)))

    return