Example #1
def create_optimizer(params, lr, optim_type, optim_eps, optim_a1, optim_a2,
                     lam, verbose):
    optim_params = [{
        'params': [v for k, v in sorted(params.items()) if v.requires_grad]
    }]
    if optim_type == RMSprop_str:
        alpha = optim_a1 if optim_a1 > 0 else 0.99  # PyTorch's default
        eps = optim_eps if optim_eps > 0 else 1e-8  # PyTorch's default
        msg = 'Creating RMSprop optimizer with lr=' + str(lr) + ', lam=' + str(
            lam) + ', alpha=' + str(alpha) + ', eps=' + str(eps)
        optim = RMSprop(optim_params,
                        lr,
                        weight_decay=lam,
                        alpha=alpha,
                        eps=eps)
    elif optim_type == Adam_str:
        # NOTE: not tested.
        eps = optim_eps if optim_eps > 0 else 1e-8  # PyTorch's default
        a1 = optim_a1 if optim_a1 > 0 else 0.9  # PyTorch's default
        a2 = optim_a2 if optim_a2 > 0 else 0.999  # PyTorch's default
        msg = 'Creating Adam optimizer with lr=%s, lam=%s, eps=%s, betas=(%s,%s)' % (
            str(lr), str(lam), str(eps), str(a1), str(a2))
        optim = Adam(optim_params,
                     lr,
                     betas=(a1, a2),
                     eps=eps,
                     weight_decay=lam)
    else:
        raise ValueError('Unknown optim_type: %s' % optim_type)

    if verbose:
        timeLog(msg)

    optim.zero_grad()
    return optim
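
A minimal usage sketch (not from the repository) of the same parameter-group pattern, applied directly to torch.optim.RMSprop; the toy parameter dict and the hyperparameter values below are assumptions chosen only for illustration.

import torch
from torch.optim import RMSprop

# Hypothetical parameter dict in the same {name: tensor} format the helper expects.
params = {'proj.w': torch.randn(16, 4, requires_grad=True)}
optim_params = [{
    'params': [v for k, v in sorted(params.items()) if v.requires_grad]
}]
optimizer = RMSprop(optim_params, lr=1e-3, weight_decay=0.0, alpha=0.99, eps=1e-8)
optimizer.zero_grad()  # start from clean gradients, as create_optimizer does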
Example #2
def cfggan(opt,
           d_config,
           g_config,
           z_gen,
           loader,
           d_loss=d_loss_dflt,
           g_loss=g_loss_dflt):

    check_opt_(opt)

    write_real(opt, loader)

    optim_config = OptimConfig(opt)
    ddg = DDG(opt, d_config, g_config, z_gen, optim_config)
    ddg.initialize_G(g_loss, opt.cfg_N)

    #---  xICFG
    iterator = None
    for stage in range(opt.num_stages):
        timeLog('xICFG stage %d -----------------' % (stage + 1))
        iterator, diff = ddg.icfg(loader, iterator, d_loss, opt.cfg_U)
        if opt.diff_max > 0 and abs(diff) > opt.diff_max and stage >= 2000:
            timeLog('Stopping as |D(real)-D(gen)| exceeded ' +
                    str(opt.diff_max) + '.')
            break

        if is_time_to_save(opt, stage):
            save_ddg(opt, ddg, stage)
        if is_time_to_generate(opt, stage):
            generate(opt, ddg, stage)

        ddg.approximate(g_loss, opt.cfg_N)
Example #3
    def initialize_G(self, g_loss, cfg_N):
        timeLog('DDG::initialize_G ... Initializing tilde(G) ... ')
        z = self.z_gen(1)
        g_out = self.g_net(cast(z), self.g_params, False)
        img_dim = g_out.view(g_out.size(0), -1).size(1)

        batch_size = self.optim_config.x_batch_size
        z_dim = self.z_gen(1).size(1)
        params = {'proj.w': normal_(torch.Tensor(z_dim, img_dim), std=0.01)}
        params['proj.w'].requires_grad = True

        num_gened = 0
        fakes = torch.Tensor(cfg_N, img_dim)
        zs = torch.Tensor(cfg_N, z_dim)
        with torch.no_grad():
            while num_gened < cfg_N:
                num = min(batch_size, cfg_N - num_gened)
                z = self.z_gen(num)
                fake = torch.mm(z, params['proj.w'])
                fakes[num_gened:num_gened + num] = fake
                zs[num_gened:num_gened + num] = z
                num_gened += num

        to_pm1(fakes)  # -> [-1,1]

        sz = [cfg_N] + list(g_out.size())[1:]
        dataset = TensorDataset(zs, fakes.view(sz))
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            pin_memory=torch.cuda.is_available())
        self._approximate(loader, g_loss)
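
At its core, the initial tilde(G) is just a random linear map from the latent vector to a flattened image. Below is a standalone sketch of that step with assumed sizes (z_dim=100, 3x32x32 images) chosen only for illustration.

import torch
from torch.nn.init import normal_

z_dim, img_dim = 100, 3 * 32 * 32                  # assumed dimensions
proj_w = normal_(torch.Tensor(z_dim, img_dim), std=0.01)
z = torch.randn(8, z_dim)                          # a small batch of latent vectors
fake = torch.mm(z, proj_w)                         # (8, img_dim); later rescaled to [-1, 1]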
Example #4
    def icfg(self, loader, iter, d_loss, cfg_U):
        timeLog('DDG::icfg ... ICFG with cfg_U=%d' % cfg_U)
        self.check_trainability()
        t_inc = 1 if self.verbose else 5
        is_train = True
        for t in range(self.num_D()):
            sum_real = sum_fake = count = 0
            for upd in range(cfg_U):
                sample, iter = get_next(loader, iter)

                num = sample[0].size(0)
                fake = self.generate(num, t=t)
                d_out_real = self.d_net(cast(sample[0]), self.d_params,
                                        is_train)
                d_out_fake = self.d_net(cast(fake), self.d_params, is_train)
                loss = d_loss(d_out_real, d_out_fake)
                loss.backward()
                self.d_optimizer.step()
                self.d_optimizer.zero_grad()

                with torch.no_grad():
                    sum_real += float(d_out_real.sum())
                    sum_fake += float(d_out_fake.sum())
                    count += num

            self.store_d_params(t)

            if t_inc > 0 and ((t + 1) % t_inc == 0 or t == self.num_D() - 1):
                logging('  t=%d: real,%s, fake,%s ' %
                        (t + 1, sum_real / count, sum_fake / count))

        raise_if_nan(sum_real)
        raise_if_nan(sum_fake)

        return iter, (sum_real - sum_fake) / count
Example #5
    def save(self, opt, path):
        timeLog('Saving: ' + path + ' ... ')
        torch.save(
            dict(d_params_list=self.d_params_list,
                 d_params=self.d_params,
                 g_params=self.g_params,
                 cfg_eta=self.cfg_eta,
                 opt=opt), path)
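
A matching load sketch (the checkpoint path is hypothetical); it mirrors the torch.load call shown in Example #12 below.

import torch

path = 'saved_ddg.pt'  # hypothetical path
saved = torch.load(path,
                   map_location=None if torch.cuda.is_available() else 'cpu')
g_params = saved['g_params']
opt = saved['opt']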
Example #6
def write_real(opt, loader):
    timeLog('write_real: ... ')
    dir = 'real'
    if not os.path.exists(dir):
        os.mkdir(dir)

    real, _ = get_next(loader, None)
    real = real[0]
    num = min(10, real.size(0))
    nm = dir + os.path.sep + opt.dataset + '-%dc' % num
    write_image(real[0:num], nm + '.jpg', nrow=5)
Example #7
    def approximate(self, g_loss, cfg_N):
        timeLog('DDG::approximate ... cfg_N=%d' % cfg_N)
        batch_size = self.optim_config.x_batch_size
        target_fakes, zs = self.generate(cfg_N,
                                         do_return_z=True,
                                         batch_size=batch_size)
        dataset = TensorDataset(zs, target_fakes)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            pin_memory=torch.cuda.is_available())

        self._approximate(loader, g_loss)
Example #8
def get_ds(dataset, dataroot, is_train, do_download, do_augment):
    tr = get_tr(dataset, is_train, do_augment)
    if dataset == 'SVHN':
        if is_train:
            train_ds = datasets.SVHN(dataroot,
                                     split='train',
                                     transform=tr,
                                     download=do_download)
            extra_ds = datasets.SVHN(dataroot,
                                     split='extra',
                                     transform=tr,
                                     download=do_download)
            return ConcatDataset([train_ds, extra_ds])
        else:
            return datasets.SVHN(dataroot,
                                 split='test',
                                 transform=tr,
                                 download=do_download)
    elif dataset == 'MNIST':
        return getattr(datasets, dataset)(dataroot,
                                          train=is_train,
                                          transform=tr,
                                          download=do_download)
    elif dataset.startswith('lsun_') and dataset.endswith('64'):
        nm = dataset[len('lsun_'):-len('64')] + ('_train' if is_train else '_val')
        if nm.startswith('brlr'):
            indexes = list(range(1300000)) if is_train else list(
                range(1300000, 1315802))
            return gen_lsun_balanced(dataroot,
                                     ['bedroom_train', 'living_room_train'],
                                     tr, indexes)
        elif nm.startswith('twbg'):
            indexes = list(range(700000)) if is_train else list(
                range(700000, 708264))
            return gen_lsun_balanced(dataroot, ['tower_train', 'bridge_train'],
                                     tr, indexes)
        else:
            timeLog('Loading LSUN %s ...' % nm)
            return datasets.LSUN(dataroot, classes=[nm], transform=tr)
    else:
        raise ValueError('Unknown dataset: %s ...' % dataset)
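
A usage sketch with hypothetical arguments, wrapping the returned dataset in a DataLoader the same way proc() does in Example #13.

import torch
from torch.utils.data import DataLoader

ds = get_ds('MNIST', './data', is_train=True, do_download=True, do_augment=False)
loader = DataLoader(ds,
                    batch_size=64,
                    shuffle=True,
                    drop_last=True,
                    pin_memory=torch.cuda.is_available())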
Example #9
def generate(opt, ddg, stage=''):
    if not opt.gen or opt.num_gen <= 0:
        return

    timeLog('Generating %d ... ' % opt.num_gen)
    stg = '-stg%05d' % (stage + 1) if isinstance(stage, int) else str(stage)

    dir = os.path.dirname(opt.gen)
    if dir and not os.path.exists(dir):  # guard against an empty dirname
        os.makedirs(dir)

    fake = ddg.generate(opt.num_gen)

    if opt.gen_nrow > 0:
        nm = opt.gen + '%s-%dc' % (stg, opt.num_gen)  # 'c' for collage or collection
        write_image(fake, nm + '.jpg', nrow=opt.gen_nrow)
    else:
        for i in range(opt.num_gen):
            nm = opt.gen + ('%s-%d' % (stg, i))
            write_image(fake[i], nm + '.jpg')

    timeLog('Done with generating %d ... ' % opt.num_gen)
Example #10
    def _approximate(self, loader, g_loss):
        if self.verbose:
            timeLog('DDG::_approximate using %d data points ...' %
                    len(loader.dataset))
        self.check_trainability()
        with torch.no_grad():
            g_params = clone_params(self.g_params, do_copy_requires_grad=True)

        optimizer = self.optim_config.create_optimizer(g_params)
        mtr_loss = tnt.meter.AverageValueMeter()
        last_loss_mean = 99999999
        is_train = True
        for epoch in range(self.optim_config.cfg_x_epo):
            for sample in loader:
                z = cast(sample[0])
                target_fake = cast(sample[1])

                fake = self.g_net(z, g_params, is_train)

                loss = g_loss(fake, target_fake)
                mtr_loss.add(float(loss))
                loss.backward()

                optimizer.step()
                optimizer.zero_grad()

            loss_mean = mtr_loss.value()[0]
            if self.verbose:
                logging('%d ... %s ... ' % (epoch, str(loss_mean)))
            if loss_mean > last_loss_mean:
                self.optim_config.reduce_lr_(optimizer)
            raise_if_nan(loss_mean)

            last_loss_mean = loss_mean
            mtr_loss.reset()

        copy_params(src=g_params, dst=self.g_params)
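
The per-epoch loss averaging relies on torchnet's AverageValueMeter; here is a small standalone sketch of that pattern (the values are made up).

import torchnet as tnt

mtr = tnt.meter.AverageValueMeter()
for v in (0.9, 0.7, 0.5):
    mtr.add(v)
mean = mtr.value()[0]  # value() returns (mean, std)
mtr.reset()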
Example #11
def change_lr_(optimizer, lr, verbose=False):
    if verbose:
        timeLog('Setting lr to ' + str(lr) + ' in place ...')
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
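
A minimal usage sketch (the toy parameter and learning rates are assumptions) showing that every parameter group is updated in place.

import torch

w = torch.zeros(3, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)
change_lr_(optimizer, 0.01)
assert all(group['lr'] == 0.01 for group in optimizer.param_groups)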
Example #12
def main():
    parser = ArgParser_HelpWithDefaults(
        description='cfggan_gen',
        formatter_class=argparse.MetavarTypeHelpFormatter)
    parser.add_argument('--seed', type=int, default=1, help='Random seed.')
    parser.add_argument('--gen',
                        type=str,
                        required=True,
                        help='Pathname for writing generated images.')
    parser.add_argument('--saved',
                        type=str,
                        required=True,
                        help='Pathname for the saved model.')
    parser.add_argument('--num_gen',
                        type=int,
                        default=40,
                        help='Number of images to be generated.')
    parser.add_argument('--gen_nrow',
                        type=int,
                        default=8,
                        help='Number of images in each row when making a collage. '
                        '-1: No collage.')

    gen_opt = parser.parse_args()
    show_args(gen_opt, ['seed', 'gen', 'saved', 'num_gen'])

    torch.manual_seed(gen_opt.seed)
    np.random.seed(gen_opt.seed)

    timeLog('Reading from %s ... ' % gen_opt.saved)
    from_file = torch.load(
        gen_opt.saved,
        map_location=None if torch.cuda.is_available() else 'cpu')

    opt = from_file['opt']

    #---  these must be in sync with cfggan_train  --------------
    def d_config(requires_grad):  # D
        if opt.d_model == DCGANx:
            return netdef.dcganx_D(opt.d_dim,
                                   opt.image_size,
                                   opt.channels,
                                   opt.norm_type,
                                   requires_grad,
                                   depth=opt.d_depth,
                                   do_bias=not opt.do_no_bias)
        elif opt.d_model == Resnet4:
            return netdef.resnet4_D(opt.d_dim,
                                    opt.image_size,
                                    opt.channels,
                                    opt.norm_type,
                                    requires_grad,
                                    do_bias=not opt.do_no_bias)
        else:
            raise ValueError('Unknown d_model: %s' % opt.d_model)

    def g_config(requires_grad):  # G
        if opt.g_model == DCGANx:
            return netdef.dcganx_G(opt.z_dim,
                                   opt.g_dim,
                                   opt.image_size,
                                   opt.channels,
                                   opt.norm_type,
                                   requires_grad,
                                   depth=opt.g_depth,
                                   do_bias=not opt.do_no_bias)
        elif opt.g_model == Resnet4:
            return netdef.resnet4_G(opt.z_dim,
                                    opt.g_dim,
                                    opt.image_size,
                                    opt.channels,
                                    opt.norm_type,
                                    requires_grad,
                                    do_bias=not opt.do_no_bias)
        elif opt.g_model == FCn:
            return netdef.fcn_G(opt.z_dim,
                                opt.g_dim,
                                opt.image_size,
                                opt.channels,
                                requires_grad,
                                depth=opt.g_depth)
        else:
            raise ValueError('Unknown g_model: %s' % opt.g_model)

    def z_gen(num):
        return normal_(torch.Tensor(num, opt.z_dim), std=opt.z_std)

    #-------------------------------------------------------------

    ddg = DDG(opt,
              d_config,
              g_config,
              z_gen,
              optim_config=None,
              from_file=from_file)
    generate(gen_opt, ddg)
Example #13
def proc(opt):
    check_opt_(opt)

    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)

    ds_attr = get_ds_attr(opt.dataset)
    opt.image_size = ds_attr['image_size']
    opt.channels = ds_attr['channels']

    def d_config(requires_grad):  # D
        if opt.d_model == DCGANx:
            return netdef.dcganx_D(opt.d_dim,
                                   opt.image_size,
                                   opt.channels,
                                   opt.norm_type,
                                   requires_grad,
                                   depth=opt.d_depth,
                                   do_bias=not opt.do_no_bias)
        elif opt.d_model == Resnet4:
            if opt.d_depth != 4:
                logging('WARNING: d_depth is ignored as d_model is Resnet4.')
            return netdef.resnet4_D(opt.d_dim,
                                    opt.image_size,
                                    opt.channels,
                                    opt.norm_type,
                                    requires_grad,
                                    do_bias=not opt.do_no_bias)
        else:
            raise ValueError('Unknown d_model: %s' % opt.d_model)

    def g_config(requires_grad):  # G
        if opt.g_model == DCGANx:
            return netdef.dcganx_G(opt.z_dim,
                                   opt.g_dim,
                                   opt.image_size,
                                   opt.channels,
                                   opt.norm_type,
                                   requires_grad,
                                   depth=opt.g_depth,
                                   do_bias=not opt.do_no_bias)
        elif opt.g_model == Resnet4:
            if opt.g_depth != 4:
                logging('WARNING: g_depth is ignored as g_model is Resnet4.')
            return netdef.resnet4_G(opt.z_dim,
                                    opt.g_dim,
                                    opt.image_size,
                                    opt.channels,
                                    opt.norm_type,
                                    requires_grad,
                                    do_bias=not opt.do_no_bias)
        elif opt.g_model == FCn:
            return netdef.fcn_G(opt.z_dim,
                                opt.g_dim,
                                opt.image_size,
                                opt.channels,
                                requires_grad,
                                depth=opt.g_depth)
        else:
            raise ValueError('Unknown g_model: %s' % opt.g_model)

    def z_gen(num):
        return normal_(torch.Tensor(num, opt.z_dim), std=opt.z_std)

    ds = get_ds(opt.dataset,
                opt.dataroot,
                is_train=True,
                do_download=opt.do_download,
                do_augment=opt.do_augment)
    timeLog('#train = %d' % len(ds))
    loader = DataLoader(ds,
                        opt.batch_size,
                        shuffle=True,
                        drop_last=True,
                        num_workers=opt.num_workers,
                        pin_memory=torch.cuda.is_available())

    cfggan(opt, d_config, g_config, z_gen, loader)