示例#1
0
def _main() -> None:
    """Entry point of the script.

    Enables eager execution, builds the fine-tuning MobileNetV2 model,
    logs its loss/accuracy on 20 validation batches before training, then
    trains for two epochs with the callback returned by
    ``utils.load_checkpoints`` and plots the resulting history.
    """
    import logging

    logging.basicConfig(level=logging.INFO)
    tfv1.enable_eager_execution()

    train_ds, val_ds, _, metadata = datasets.get_batch_dataset(shuffle_seed=0)

    learning_rate = 0.0001
    net = network_ft.MobileNetV2FT()
    rmsprop = tf.keras.optimizers.RMSprop(lr=learning_rate)
    net.compile(optimizer=rmsprop, loss="binary_crossentropy", metrics=["accuracy"])

    # Baseline metrics before any fine-tuning.
    loss0, accuracy0 = net.evaluate(val_ds, steps=20)
    logger.info(f"initial loss: {loss0:.2f}, acc: {accuracy0:.2f}")

    # Fine-tune; the object from utils.load_checkpoints is used as a Keras callback.
    ckpt_callback = utils.load_checkpoints(net, save_dir="_data/ckpt_finetuning")
    history = net.fit(train_ds, epochs=2, validation_data=val_ds, callbacks=[ckpt_callback])
    utils.plot_history(history)
def main(hparams):

    # Set up some stuff accoring to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    utils.set_num_measurements(hparams)
    utils.print_hparams(hparams)

    # get inputs
    data_dict = model_input(hparams)

    estimator = utils.get_estimator(hparams, hparams.model_types[0])
    print(estimator)
    hparams.checkpoint_dir = utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    h_hats_dict = {model_type: {} for model_type in hparams.model_types}
    for key, x in data_dict.iteritems():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all([
                os.path.isfile(save_path) for save_path in save_paths.values()
            ])
            if is_saved:
                continue

        # Get Rx data
        Rx = data_dict[key]['Rx_data']
        Tx = data_dict[key]['Tx_data']
        H = data_dict[key]['H_data']
        Pilot_Rx = utils.get_pilot(Rx)
        print('Pilot_shape', Pilot_Rx.shape)
        Pilot_Rx = Pilot_Rx[0::2] + Pilot_Rx[1::2] * 1j
        Pilot_Tx = utils.get_pilot(Tx)
        Pilot_Tx = Pilot_Tx[0::2] + Pilot_Tx[1::2] * 1j
        Pilot_complex = Pilot_Rx / Pilot_Tx
        Pilot = np.empty((Pilot_complex.size * 2, ), dtype=Pilot_Rx.dtype)
        Pilot[0::2] = np.real(Pilot_complex)
        Pilot[1::2] = np.imag(Pilot_complex)

        Pilot = np.reshape(Pilot, [1, -1]) / 2.5
        # Construct estimates using each estimator
        h_hat = estimator(Tx, Rx, Pilot, hparams)

        # Compute and store measurement and l2 loss
        #        measurement_losses['dcgan'][key] = utils.get_measurement_loss(h_hat, Tx, Rx)
        #        l2_losses['dcgan'][key] = utils.get_l2_loss(h_hat, H)

        print "Processed upto image {0} / {1}".format(key + 1, len(data_dict))

        # Checkpointing
        if (hparams.save_images) and ((key + 1) % hparams.checkpoint_iter
                                      == 0):
            # utils.checkpoint(key,h_hat, measurement_losses, l2_losses, save_image, hparams)
            utils.save_channel_image(key + 1, h_hat, hparams)
            utils.save_channel_mat(key + 1, h_hat, hparams)

            print '\nProcessed and saved first ', key + 1, 'channels\n'
    def __init__(self, config):
        """Build the FCN variant selected by ``config.network``.

        Args:
            config: options object; reads ``network`` ('fcn32s' or 'fcn16s'),
                ``load_iter``, ``checkpoint_dir``, ``is_train`` and, when
                training, ``lr``, ``momentum`` and ``weight_decay``.

        Raises:
            ValueError: if ``config.network`` is not a supported variant.
        """
        super(FCN, self).__init__()

        if config.network == 'fcn32s':
            # Pretrained VGG16 weights are requested only when load_iter == 0;
            # otherwise the weights are presumably restored from a checkpoint.
            vgg16 = torchvision.models.vgg16(pretrained=(config.load_iter == 0))
            self.network = FCN32s(vgg16)
        elif config.network == 'fcn16s':
            # FCN-16s wraps an FCN-32s. When starting fresh, the FCN-32s
            # weights are loaded from its checkpoint (100000 is presumably
            # the saved iteration number).
            vgg16 = torchvision.models.vgg16(pretrained=False)
            fcn32s = FCN32s(vgg16)
            if config.load_iter == 0:
                load_checkpoints(fcn32s, 'fcn32s', config.checkpoint_dir, 100000)
            self.network = FCN16s(fcn32s)
        else:
            # Fail fast instead of leaving self.network undefined.
            raise ValueError(
                "Unsupported network %r; expected 'fcn32s' or 'fcn16s'"
                % (config.network,))

        if config.is_train:
            self.criterion = nn.CrossEntropyLoss(ignore_index=255)
            self.optimizer = optim.SGD(self.network.parameters(), lr=config.lr,
                                       momentum=config.momentum, weight_decay=config.weight_decay)
def main():
    """Prepare output folders, build the GAN, then train or test it.

    Reads the module-level ``config``:
    - ``tensorboard_dir`` and ``checkpoint_dir`` each get a ``config.name``
      subdirectory, created when missing;
    - ``use_cuda`` selects the device;
    - a non-zero ``load_epoch`` is passed to ``load_checkpoints``;
    - ``is_train`` chooses between ``train`` and ``test``.
    """
    for base_dir in (config.tensorboard_dir, config.checkpoint_dir):
        run_dir = os.path.join(base_dir, config.name)
        if not os.path.exists(run_dir):
            os.makedirs(run_dir)

    device = torch.device('cuda:0' if config.use_cuda else 'cpu')
    models = GAN(config).to(device)
    if config.load_epoch != 0:
        load_checkpoints(models, config.checkpoint_dir, config.name,
                         config.load_epoch)

    if not config.is_train:
        models.eval()
        test(models, device)
        return

    models.train()
    log_dir = os.path.join(config.tensorboard_dir, config.name)
    writer = SummaryWriter(log_dir=log_dir)
    train(models, writer, device)
示例#5
0
def main(hparams):

    # Set up some stuff accoring to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    utils.set_num_measurements(hparams)
    utils.print_hparams(hparams)

    # get inputs
    data_dict = model_input(hparams)

    estimator = utils.get_estimator(hparams, 'vae')
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    h_hats_dict = {model_type: {} for model_type in hparams.model_types}
    for key, x in data_dict.iteritems():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all([
                os.path.isfile(save_path) for save_path in save_paths.values()
            ])
            if is_saved:
                continue

        # Get Rx data
        Rx = data_dict[key]['Rx_data']
        Tx = data_dict[key]['Rx_data']
        H = data_dict[key]['H_data']

        # Construct estimates using each estimator
        h_hat = estimator(Tx, Rx, hparams)

        # Save the estimate
        h_hats_dict['vae'][key] = h_hat

        # Compute and store measurement and l2 loss
        measurement_losses['vae'][key] = utils.get_measurement_loss(
            h_hat, Tx, Rx)
        l2_losses['vae'][key] = utils.get_l2_loss(h_hat, H)

        print 'Processed upto image {0} / {1}'.format(key + 1, len(data_dict))

        # Checkpointing
        if (hparams.save_images) and ((key + 1) % hparams.checkpoint_iter
                                      == 0):
            utils.checkpoint(key, h_hat, measurement_losses, l2_losses,
                             save_image, hparams)
            print '\nProcessed and saved first ', key + 1, 'channels\n'
def get_topk(color_info, k):
    """Return the ``k``-th element of every value in ``color_info``.

    Values are visited in the dict's iteration order, and the results are
    returned as a new list.
    """
    return [entry[k] for entry in color_info.values()]


# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output-directory creation is currently disabled:
# out_root = './data/colorize_result'
# if not os.path.exists(out_root):
#     os.mkdir(out_root)
generator = 'deepunetG_030.pth.tar'  # checkpoint file passed to load_checkpoints

# Build the generator on the chosen device, restore its weights, then
# freeze every parameter so no gradients are tracked.
model = DeepUNetPaintGenerator().to(device)
load_checkpoints(generator, model, device_type=device.type)
for param in model.parameters():
    param.requires_grad = False


def act(stylefile, testfile):
    """Validate the command line and derive the output file name.

    The output name is built from ``sys.argv[1]`` and ``sys.argv[2]``. The
    last four characters (e.g. a ``.png`` extension) are dropped from each,
    and ``.png`` is appended.

    NOTE(review): ``style_f``, ``test_f`` and ``filename`` are not used in the
    visible body; the function looks truncated.

    Raises:
        RuntimeError: if fewer than two command-line arguments were given.
    """
    argv = sys.argv
    if len(argv) < 3:
        raise RuntimeError(
            'Command Line Argument Must be (sketch file, style file)')

    style_f = stylefile  # previously './data/styles/%s' % sys.argv[2]
    test_f = testfile    # previously './data/test/%s' % sys.argv[1]

    sketch_arg, style_arg = argv[1], argv[2]
    filename = sketch_arg[:-4] + style_arg[:-4] + '.png'
def main(hparams):
    """Recover every input image with each configured estimator, in batches.

    Images are grouped into batches of ``hparams.batch_size``. For each full
    batch, measurements are formed with ``utils.get_measurements`` and every
    estimator in ``hparams.model_types`` is run. Per image, this records:
    - the measurement loss returned by the estimator;
    - the l2 loss;
    - the LPIPS score (via ``utils.get_lpips_score``);
    - the latent estimate.
    Results are checkpointed every ``hparams.checkpoint_iter`` images and at
    the end. Images that do not fill a final batch are skipped, with a
    warning.

    NOTE(review): ``hparams.n_input`` is read but not set here; the caller
    is assumed to set it.
    """
    # set up perceptual loss
    # NOTE(review): the device string is hard-coded, so use_gpu is always True.
    device = 'cuda:0'
    percept = PerceptualLoss(
            model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    utils.print_hparams(hparams)

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses, lpips_scores, z_hats = utils.load_checkpoints(hparams)

    x_hats_dict = {model_type : {} for model_type in hparams.model_types}
    x_batch_dict = {}

    # The measurement matrix and a single Student-t (2 degrees of freedom)
    # noise realisation are drawn once and reused for every batch.
    A = utils.get_A(hparams)
    noise_batch = hparams.noise_std * np.random.standard_t(2, size=(hparams.batch_size, hparams.num_measurements))



    for key, x in xs_dict.items():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all([os.path.isfile(save_path) for save_path in save_paths.values()])
            if is_saved:
                continue

        # Accumulate images until a full batch is available.
        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input: stack the batch into a (batch_size, n_input) matrix.
        x_batch_list = [x.reshape(1, hparams.n_input) for _, x in x_batch_dict.items()]
        x_batch = np.concatenate(x_batch_list)

        # Construct noise and measurements


        y_batch = utils.get_measurements(x_batch, A, noise_batch, hparams)

        # Construct estimates using each estimator
        for model_type in hparams.model_types:
            estimator = estimators[model_type]
            x_hat_batch, z_hat_batch, m_loss_batch = estimator(A, y_batch, hparams)

            # Note: this loop rebinds `key`; afterwards it is the last key of
            # the batch, which the progress/checkpoint code below relies on.
            for i, key in enumerate(x_batch_dict.keys()):
                x = xs_dict[key]
                y_train = y_batch[i]  # NOTE(review): unused
                x_hat = x_hat_batch[i]

                # Save the estimate
                x_hats_dict[model_type][key] = x_hat

                # Compute and store measurement and l2 loss
                # NOTE(review): m_loss_batch is indexed by image key while the
                # other per-image values use the batch position i; confirm the
                # estimator's return type.
                measurement_losses[model_type][key] = m_loss_batch[key]
                l2_losses[model_type][key] = utils.get_l2_loss(x_hat, x)
                lpips_scores[model_type][key] = utils.get_lpips_score(percept, x_hat, x, hparams.image_shape)
                z_hats[model_type][key] = z_hat_batch[i]

        # key + 1 assumes integer, 0-based image keys.
        print('Processed upto image {0} / {1}'.format(key+1, len(xs_dict)))

        # Checkpointing
        if (hparams.save_images) and ((key+1) % hparams.checkpoint_iter == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
            # After this reset x_hats_dict only holds estimates produced since
            # the last checkpoint (this is also what image_matrix sees below).
            x_hats_dict = {model_type : {} for model_type in hparams.model_types}
            print('\nProcessed and saved first ', key+1, 'images\n')

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            measurement_loss_list = list(measurement_losses[model_type].values())
            l2_loss_list = list(l2_losses[model_type].values())
            mean_m_loss = np.mean(measurement_loss_list)
            mean_l2_loss = np.mean(l2_loss_list)
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processsed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
示例#8
0
def main(hparams):
    hparams.n_input = np.prod(hparams.image_shape)
    hparams.model_type = 'vae'
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)
    xs_dict = model_input(hparams)  # returns the images
    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {'vae': {}}
    x_batch_dict = {}

    for key, x in xs_dict.iteritems():
        print key
        x_batch_dict[key] = x  #placing images in dictionary
        if len(x_batch_dict) < hparams.batch_size:
            continue
        x_coll = [
            x.reshape(1, hparams.n_input) for _, x in x_batch_dict.iteritems()
        ]  #Generates the columns of input x
        x_batch = np.concatenate(x_coll)  # Generates entire X

        A_outer = utils.get_outer_A(hparams)  # Created the random matric A

        noise_batch = hparams.noise_std * np.random.randn(
            hparams.batch_size, 100)

        y_batch_outer = np.sign(
            np.matmul(x_batch, A_outer)
        )  # Multiplication of A and X followed by quantization on 4 levels

        #y_batch_outer = np.matmul(x_batch, A_outer)

        x_main_batch = 0.0 * x_batch
        z_opt_batch = np.random.randn(hparams.batch_size,
                                      20)  #Input to the generator of the GAN

        for k in range(maxiter):

            x_est_batch = x_main_batch + hparams.outer_learning_rate * (
                np.matmul(
                    (y_batch_outer -
                     np.sign(np.matmul(x_main_batch, A_outer))), A_outer.T))
            #x_est_batch = x_main_batch + hparams.outer_learning_rate * (np.matmul((y_batch_outer - np.matmul(x_main_batch, A_outer)), A_outer.T))
            # Gradient decent in x is done
            estimator = estimators['vae']
            x_hat_batch, z_opt_batch = estimator(
                x_est_batch, z_opt_batch, hparams)  # Projectin on the GAN
            x_main_batch = x_hat_batch

        dist = np.linalg.norm(x_batch - x_main_batch) / 784
        print 'cool'
        print dist

        for i, key in enumerate(x_batch_dict.keys()):
            x = xs_dict[key]
            y = y_batch_outer[i]
            x_hat = x_hat_batch[i]

            # Save the estimate
            x_hats_dict['vae'][key] = x_hat

            # Compute and store measurement and l2 loss
            measurement_losses['vae'][key] = utils.get_measurement_loss(
                x_hat, A_outer, y)
            l2_losses['vae'][key] = utils.get_l2_loss(x_hat, x)
        print 'Processed upto image {0} / {1}'.format(key + 1, len(xs_dict))

        # Checkpointing
        if (hparams.save_images) and ((key + 1) % hparams.checkpoint_iter
                                      == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             save_image, hparams)
            #x_hats_dict = {'dcgan' : {}}
            print '\nProcessed and saved first ', key + 1, 'images\n'

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         save_image, hparams)
        print '\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print model_type
            mean_m_loss = np.mean(measurement_losses[model_type].values())
            mean_l2_loss = np.mean(l2_losses[model_type].values())
            print 'mean measurement loss = {0}'.format(mean_m_loss)
            print 'mean l2 loss = {0}'.format(mean_l2_loss)

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processsed
    if len(x_batch_dict) > 0:
        print '\nDid NOT process last {} images because they did not fill up the last batch.'.format(
            len(x_batch_dict))
        print 'Consider rerunning lazily with a smaller batch size.'
示例#9
0
def main(hparams):
    # Set up some stuff according to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {'dcgan' : {}}
    x_batch_dict = {}
    for key, x in xs_dict.iteritems():
        if hparams.lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all([os.path.isfile(save_path) for save_path in save_paths.values()])
            if is_saved:
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input
        x_batch_list = [x.reshape(1, hparams.n_input) for _, x in x_batch_dict.iteritems()]
        x_batch = np.concatenate(x_batch_list)

        # Construct measurements
        A_outer = utils.get_outer_A(hparams)

        y_batch_outer=np.matmul(x_batch, A_outer)


        x_main_batch = 0.0 * x_batch
        z_opt_batch = np.random.randn(hparams.batch_size, 100)
        for k in range(maxiter):

            x_est_batch=x_main_batch + hparams.outer_learning_rate*(np.matmul((y_batch_outer-np.matmul(x_main_batch,A_outer)),A_outer.T))



            estimator = estimators['dcgan']
            x_hat_batch,z_opt_batch = estimator(x_est_batch,z_opt_batch, hparams)
            x_main_batch=x_hat_batch


        for i, key in enumerate(x_batch_dict.keys()):
            x = xs_dict[key]
            y = y_batch_outer[i]
            x_hat = x_hat_batch[i]

            # Save the estimate
            x_hats_dict['dcgan'][key] = x_hat

            # Compute and store measurement and l2 loss
            measurement_losses['dcgan'][key] = utils.get_measurement_loss(x_hat, A_outer, y)
            l2_losses['dcgan'][key] = utils.get_l2_loss(x_hat, x)
        print 'Processed upto image {0} / {1}'.format(key+1, len(xs_dict))

        # Checkpointing
        if (hparams.save_images) and ((key+1) % hparams.checkpoint_iter == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, save_image, hparams)
            #x_hats_dict = {'dcgan' : {}}
            print '\nProcessed and saved first ', key+1, 'images\n'

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, save_image, hparams)
        print '\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print model_type
            mean_m_loss = np.mean(measurement_losses[model_type].values())
            mean_l2_loss = np.mean(l2_losses[model_type].values())
            print 'mean measurement loss = {0}'.format(mean_m_loss)
            print 'mean l2 loss = {0}'.format(mean_l2_loss)

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processsed
    if len(x_batch_dict) > 0:
        print '\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict))
        print 'Consider rerunning lazily with a smaller batch size.'
from utils import load_checkpoints

from preprocess import PairedDataset

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output directory for generated images; makedirs also creates missing
# parent directories (os.mkdir failed when ./data did not exist).
out_root = './data/attention_result'
if not os.path.exists(out_root):
    os.makedirs(out_root)
generator = 'deepunetG_015.pth.tar'  # checkpoint file passed to load_checkpoints

# Build the generator with all parameters frozen, move it to the device and
# restore its weights.
model = DeepUNetPaintGenerator()
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)
# NOTE(review): the colorize script passes device_type=device.type here;
# confirm whether this load_checkpoints needs it on CPU-only hosts.
load_checkpoints(generator, model)

val_data = PairedDataset(
    transform=transforms.ToTensor(),
    mode='val',
    color_histogram=True,
)
length = len(val_data)

# Draw 2 * num_pairs distinct validation indices: the first half are the
# targets, the second half the style references. range(length) makes every
# index, including the last one, eligible.
num_pairs = 1700
idxs = random.sample(range(length), 2 * num_pairs)

targets = idxs[0:num_pairs]
styles = idxs[num_pairs:2 * num_pairs]

to_pil = transforms.ToPILImage()
for i, (target, style) in enumerate(zip(targets, styles)):
def main(hparams):
    """Recover every input image with each configured estimator, in batches.

    Images are grouped into batches of ``hparams.batch_size``. For each full
    batch, a fresh matrix A and Gaussian noise are drawn. The measurements
    are ``x + noise`` for measurement_type 'project', and ``x A + noise``
    otherwise. Every estimator in ``hparams.model_types`` is run on them, and
    per-image measurement and l2 losses are recorded. Results are
    checkpointed every ``hparams.checkpoint_iter`` images and at the end.

    Uses ``dict.items()`` and materialises ``dict.values()`` with ``list``
    so the function runs under Python 3 (as its ``print()`` calls intend)
    as well as Python 2.

    Args:
        hparams: configuration object; attributes read here include
            image_shape, not_lazy, batch_size, noise_std, num_measurements,
            measurement_type, model_types, save_images, checkpoint_iter,
            print_stats and image_matrix (``n_input`` is set here).
    """

    # Set up some stuff according to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    utils.set_num_measurements(hparams)
    utils.print_hparams(hparams)

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {model_type: {} for model_type in hparams.model_types}
    x_batch_dict = {}
    for key, x in xs_dict.items():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all([
                os.path.isfile(save_path) for save_path in save_paths.values()
            ])
            if is_saved:
                continue

        # Accumulate images until a full batch is available.
        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input: stack the batch into a (batch_size, n_input) matrix.
        x_batch_list = [
            x.reshape(1, hparams.n_input) for _, x in x_batch_dict.items()
        ]
        x_batch = np.concatenate(x_batch_list)

        # Construct noise and measurements
        A = utils.get_A(hparams)
        noise_batch = hparams.noise_std * np.random.randn(
            hparams.batch_size, hparams.num_measurements)
        if hparams.measurement_type == 'project':
            y_batch = x_batch + noise_batch
        else:
            y_batch = np.matmul(x_batch, A) + noise_batch

        # Construct estimates using each estimator
        for model_type in hparams.model_types:
            estimator = estimators[model_type]
            x_hat_batch = estimator(A, y_batch, hparams)

            for i, key in enumerate(x_batch_dict.keys()):
                x = xs_dict[key]
                y = y_batch[i]
                x_hat = x_hat_batch[i]

                # Save the estimate
                x_hats_dict[model_type][key] = x_hat

                # Compute and store measurement and l2 loss
                measurement_losses[model_type][
                    key] = utils.get_measurement_loss(x_hat, A, y)
                l2_losses[model_type][key] = utils.get_l2_loss(x_hat, x)

        print('Processed upto image {0} / {1}'.format(key + 1, len(xs_dict)))

        # Checkpointing
        if (hparams.save_images) and ((key + 1) % hparams.checkpoint_iter
                                      == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             save_image, hparams)
            # Only estimates made since the last checkpoint are kept in memory.
            x_hats_dict = {
                model_type: {}
                for model_type in hparams.model_types
            }
            print('\nProcessed and saved first ', key + 1, 'images\n')

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            # list(...) is required on Python 3: np.mean cannot average a
            # dict_values view.
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if len(x_batch_dict) > 0:
        print(
            '\nDid NOT process last {} images because they did not fill up the last batch.'
            .format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
示例#12
0
def main(hparams):
    hparams.n_input = np.prod(hparams.image_shape)
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)
    xs_dict = model_input(hparams)
    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)
    x_hats_dict = {'dcgan': {}}
    x_batch_dict = {}
    for key, x in xs_dict.iteritems():
        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue
        x_coll = [
            x.reshape(1, hparams.n_input) for _, x in x_batch_dict.iteritems()
        ]
        x_batch = np.concatenate(x_coll)
        A_outer = utils.get_outer_A(hparams)
        # 1bitify
        y_batch_outer = np.sign(np.matmul(x_batch, A_outer))

        x_main_batch = 0.0 * x_batch
        z_opt_batch = np.random.randn(hparams.batch_size, 100)
        for k in range(maxiter):
            x_est_batch = x_main_batch + hparams.outer_learning_rate * (
                np.matmul(
                    (y_batch_outer -
                     np.sign(np.matmul(x_main_batch, A_outer))), A_outer.T))
            estimator = estimators['dcgan']
            x_hat_batch, z_opt_batch = estimator(x_est_batch, z_opt_batch,
                                                 hparams)
            x_main_batch = x_hat_batch

        for i, key in enumerate(x_batch_dict.keys()):
            x = xs_dict[key]
            y = y_batch_outer[i]
            x_hat = x_hat_batch[i]
            x_hats_dict['dcgan'][key] = x_hat
            measurement_losses['dcgan'][key] = utils.get_measurement_loss(
                x_hat, A_outer, y)
            l2_losses['dcgan'][key] = utils.get_l2_loss(x_hat, x)
        print 'Processed upto image {0} / {1}'.format(key + 1, len(xs_dict))
        if (hparams.save_images) and ((key + 1) % hparams.checkpoint_iter
                                      == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             save_image, hparams)
            print '\nProcessed and saved first ', key + 1, 'images\n'

        x_batch_dict = {}

    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         save_image, hparams)
        print '\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print model_type
            mean_m_loss = np.mean(measurement_losses[model_type].values())
            mean_l2_loss = np.mean(l2_losses[model_type].values())
            print 'mean measurement loss = {0}'.format(mean_m_loss)
            print 'mean l2 loss = {0}'.format(mean_l2_loss)

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processsed
    if len(x_batch_dict) > 0:
        print '\nDid NOT process last {} images because they did not fill up the last batch.'.format(
            len(x_batch_dict))
        print 'Consider rerunning lazily with a smaller batch size.'
示例#13
0
def test(cfg, logger, vis):
    """Evaluate the latest checkpoint on the test split and log its metrics.

    For each batch (O, B), the target is the residual ``R = O - B``. The
    model returns a sequence of residual estimates ``O_Rs``. For every
    estimate, two values are logged per batch:
    - the criterion loss against R;
    - the SSIM between ``O - O_R`` and ``O - R`` (i.e. B).
    Sample-weighted averages over the whole split are logged at the end.

    Args:
        cfg: configuration dict; reads the 'model', 'optimizer', 'critical',
            'checkpoint_dir', 'data', 'batch_size', 'n_workers' and
            'show_dir' entries.
        logger: object whose ``info`` method receives all log lines.
        vis: optional visualiser with ``plot`` and ``images`` methods, or None.
    """
    torch.cuda.manual_seed_all(66)
    torch.manual_seed(66)

    # Setup model, optimizer and loss function
    model_cls = get_model(cfg['model'])
    model = model_cls(cfg).to(device)

    # The optimizer is only used by load_checkpoints below.
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["optimizer"].items() if k != "name"
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)

    crit = get_critical(cfg['critical'])().to(device)
    ssim = SSIM().to(device)

    model.eval()
    _, step = load_checkpoints(model,
                               optimizer,
                               cfg['checkpoint_dir'],
                               name='latest')

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    test_loader = data_loader(data_path,
                              split=cfg["data"]["test_split"],
                              patch_size=cfg['data']['patch_size'],
                              augmentation=cfg['data']['aug_data'])

    testloader = DataLoader(
        test_loader,
        batch_size=cfg["batch_size"],
        num_workers=cfg["n_workers"],
        shuffle=True,
    )

    all_num = 0
    all_losses = {}
    for i, batch in enumerate(testloader):

        O, B = batch
        O = Variable(O.to(device), requires_grad=False)
        B = Variable(B.to(device), requires_grad=False)
        R = O - B

        with torch.no_grad():
            O_Rs = model(O)
            loss_list = [crit(O_R, R) for O_R in O_Rs]
            ssim_list = [ssim(O - O_R, O - R) for O_R in O_Rs]

        # Distinct comprehension names: neither the batch index `i` nor the
        # `ssim` metric module is shadowed.
        losses = {
            'loss%d' % j: loss.item()
            for j, loss in enumerate(loss_list)
        }
        losses.update({
            'ssim%d' % j: score.item()
            for j, score in enumerate(ssim_list)
        })

        # The last estimate in O_Rs is used as the final prediction.
        prediction = O - O_Rs[-1]

        batch_size = O.size(0)

        all_num += batch_size
        for key, val in losses.items():
            # Sample-weighted running sums; .get avoids relying on every key
            # being present in the first batch.
            all_losses[key] = all_losses.get(key, 0.) + val * batch_size
            logger.info('batch %d loss %s: %f' % (i, key, val))

        if vis is not None:
            for k, v in losses.items():
                vis.plot(k, v)
            vis.images(np.clip((prediction.detach().data * 255).cpu().numpy(),
                               0, 255),
                       win='pred')
            vis.images(O.data.cpu().numpy(), win='input')
            vis.images(B.data.cpu().numpy(), win='groundtruth')

        if i % 20 == 0:
            save_image(name='test',
                       img_lists=[O.cpu(), prediction.cpu(),
                                  B.cpu()],
                       path=cfg['show_dir'],
                       step=i,
                       batch_size=cfg['batch_size'])

    for key, val in all_losses.items():
        logger.info('total loss %s: %f' % (key, val / all_num))
示例#14
0
def main(hparams):
    """Recover each input image, one at a time, with every configured estimator.

    For each image, fresh Gaussian noise and a fresh Gaussian matrix A are
    drawn, and the measurements ``y = x A + noise`` are formed. Every
    estimator in ``hparams.model_types`` is run on them, and per-image
    measurement and l2 losses are recorded. Results are checkpointed every
    100 images and at the end.

    NOTE(review): ``hparams.num_measurements`` is read but not set here; the
    caller is assumed to set it.
    """
    hparams.n_input = np.prod(hparams.image_shape)
    images = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    est_images = {model_type: {} for model_type in hparams.model_types}
    for i, image in images.iteritems():

        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, i)
            if all([
                    os.path.isfile(save_path)
                    for save_path in save_paths.values()
            ]):
                continue

        # Reshape input
        x_val = image.reshape(1, hparams.n_input)

        # Construct noise and measurements. The noise is drawn before A;
        # keep this order if reproducing earlier runs' RNG stream matters.
        noise_val = hparams.noise_std * np.random.randn(
            1, hparams.num_measurements)
        A_val = np.random.randn(hparams.n_input, hparams.num_measurements)
        y_val = np.matmul(x_val, A_val) + noise_val

        # Construct estimates using each estimator
        print 'Processing image {0}'.format(i)
        for model_type in hparams.model_types:
            estimator = estimators[model_type]
            est_image = estimator(A_val, y_val, hparams)
            est_images[model_type][i] = est_image
            # Compute and store measurement and l2 loss
            measurement_losses[model_type][i] = utils.get_measurement_loss(
                est_image, A_val, y_val)
            l2_losses[model_type][i] = utils.get_l2_loss(est_image, image)

        # Checkpointing
        # NOTE(review): the period is hard-coded to 100 here, while sibling
        # drivers use hparams.checkpoint_iter. i + 1 assumes 0-based int keys.
        if (hparams.save_images) and ((i + 1) % 100 == 0):
            utils.checkpoint(est_images, measurement_losses, l2_losses,
                             save_image, hparams)
            # Only estimates made since the last checkpoint are kept in memory.
            est_images = {model_type: {} for model_type in hparams.model_types}
            print '\nProcessed and saved first ', i + 1, 'images\n'

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(est_images, measurement_losses, l2_losses, save_image,
                         hparams)
        print '\nProcessed and saved all {0} images\n'.format(len(images))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print model_type
            mean_m_loss = np.mean(measurement_losses[model_type].values())
            mean_l2_loss = np.mean(l2_losses[model_type].values())
            print 'mean measurement loss = {0}'.format(mean_m_loss)
            print 'mean l2 loss = {0}'.format(mean_l2_loss)

    if hparams.image_matrix > 0:
        utils.image_matrix(images, est_images, view_image, hparams)
示例#15
0
def _measure_batch(x_batch, A, hparams):
    """Return noisy measurements of ``x_batch`` (one row per image).

    With ``hparams.measurement_type == 'project'`` the measurement is the
    image itself plus noise; otherwise it is the magnitude ``|x_batch @ A|``
    plus noise (magnitude-only measurements, as used for phase retrieval).

    Args:
        x_batch: stacked flattened images, ``hparams.batch_size`` rows.
        A: measurement matrix (unused in 'project' mode).
        hparams: reads ``noise_std``, ``batch_size``, ``num_measurements``
            and ``measurement_type``.
    """
    noise_batch = hparams.noise_std * np.random.randn(
        hparams.batch_size, hparams.num_measurements)
    if hparams.measurement_type == 'project':
        return x_batch + noise_batch
    return np.absolute(np.matmul(x_batch, A)) + noise_batch


def _estimate_with_restarts(estimator, A, y_batch, x_batch, hparams):
    """Run ``num_restarts`` random restarts of ``estimator``; return the kept batch.

    Each restart solves from a random latent start ``z0`` and again from its
    reflection ``-z0``. Both estimates go through ``utils.resolve_ambiguity``.
    ``utils.get_optimal_x_batch`` picks between the two, and then between
    that result and the estimate kept from earlier restarts. The ground-truth
    ``x_batch`` is passed to both helpers.

    Relies on the module-level settings ``num_restarts`` and ``latent_dim``.
    """
    # Placeholder starting "best" estimate; presumably large enough that the
    # first restart's estimate is always selected by get_optimal_x_batch
    # (TODO confirm against utils.get_optimal_x_batch).
    x_best = 10000 * np.ones_like(x_batch)
    for k in range(num_restarts):
        print('Restart # {0}'.format(k + 1))

        # Solve the deep phase-retrieval problem from a random initial iterate.
        init_iter = np.random.randn(hparams.batch_size, latent_dim)
        # The estimator returns a sequence whose first item is the image batch.
        x_hat_fwd = estimator(A, y_batch, init_iter, hparams)[0]
        x_hat_fwd = utils.resolve_ambiguity(x_hat_fwd, x_batch, hparams.batch_size)

        # Solve again from the reflection of the initial iterate.
        x_hat_ref = estimator(A, y_batch, -1 * init_iter, hparams)[0]
        x_hat_ref = utils.resolve_ambiguity(x_hat_ref, x_batch, hparams.batch_size)

        x_restart = utils.get_optimal_x_batch(
            x_hat_fwd, x_hat_ref, x_batch, hparams.batch_size)
        x_best = utils.get_optimal_x_batch(
            x_restart, x_best, x_batch, hparams.batch_size)
    return x_best


def main(hparams):
    """Recover every input image in batches with each configured estimator.

    For each full batch of images from ``model_input`` this draws a
    measurement matrix via ``utils.get_A`` and takes noisy measurements.
    Then, for every estimator named in ``hparams.model_types``, it runs the
    random-restart search. Measurement and l2 losses are recorded per image.
    Estimates are checkpointed every ``hparams.checkpoint_iter`` images and
    once at the end, when ``hparams.save_images`` is set.

    Args:
        hparams: hyper-parameter namespace. It is mutated: ``n_input`` is set
            here, ``n_z`` / ``z_dim`` is set for mnist / celebA, and
            ``utils.set_num_measurements`` may set further fields.

    Images that do not fill a final batch are not processed; a warning is
    printed at the end.
    """
    # Set up derived quantities according to hparams.
    hparams.n_input = np.prod(hparams.image_shape)
    utils.set_num_measurements(hparams)
    utils.print_hparams(hparams)

    # The mnist and celebA configs store the latent size under different
    # attribute names; `latent_dim` is a module-level setting.
    if hparams.dataset == 'mnist':
        hparams.n_z = latent_dim
    elif hparams.dataset == 'celebA':
        hparams.z_dim = latent_dim

    # Get inputs: mapping from integer image key to image array.
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {model_type: {} for model_type in hparams.model_types}
    x_batch_dict = {}
    for key, x in xs_dict.items():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been saved
            # before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            if all(os.path.isfile(path) for path in save_paths.values()):
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Fix a single key order so rows of x_batch / y_batch line up with the
        # keys used when storing results below.
        batch_keys = list(x_batch_dict.keys())
        x_batch = np.concatenate(
            [x_batch_dict[k].reshape(1, hparams.n_input) for k in batch_keys])

        # Construct noise and measurements.
        A = utils.get_A(hparams)
        y_batch = _measure_batch(x_batch, A, hparams)

        # Construct estimates using each estimator.
        for model_type in hparams.model_types:
            x_hat_batch = _estimate_with_restarts(
                estimators[model_type], A, y_batch, x_batch, hparams)
            if hparams.dataset == 'mnist':
                utils.print_stats(x_hat_batch, x_batch, hparams.batch_size)

            # `batch_key` deliberately does not reuse the name `key`: `key`
            # (the image that completed this batch) drives the progress
            # message and checkpoint schedule below.
            for i, batch_key in enumerate(batch_keys):
                x_true = xs_dict[batch_key]
                y = y_batch[i]
                x_hat = x_hat_batch[i]

                # Save the estimate.
                x_hats_dict[model_type][batch_key] = x_hat

                # Compute and store measurement and l2 loss.
                measurement_losses[model_type][batch_key] = utils.get_measurement_loss(x_hat, A, y)
                l2_losses[model_type][batch_key] = utils.get_l2_loss(x_hat, x_true)
        print('Processed upto image {0} / {1}'.format(key + 1, len(xs_dict)))

        # Checkpointing
        if hparams.save_images and (key + 1) % hparams.checkpoint_iter == 0:
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, save_image, hparams)
            x_hats_dict = {model_type: {} for model_type in hparams.model_types}
            print('\nProcessed and saved first {0} images\n'.format(key + 1))

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            # list(...) so np.mean receives a sequence under Python 3 too.
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if x_batch_dict:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
示例#16
0
    def __init__(self, *args):
        super(DeepUNetTrainer, self).__init__(*args)

        # log file
        if self.args.train:
            ctime = time.ctime().split()

            log_path = './log'
            if not os.path.exists(log_path):
                os.mkdir(log_path)

            log_dir = os.path.join(
                log_path,
                '%s_%s_%s_%s' % (ctime[-1], ctime[1], ctime[2], ctime[3]))
            os.mkdir(log_dir)
            with open(os.path.join(log_dir, 'arg.txt'), 'w') as f:
                f.write(str(args))
            self.log_file = open(os.path.join(log_dir, 'loss.txt'), 'w')

        self.save_path = './data/result'
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)

        # build model
        self.generator = DeepUNetPaintGenerator().to(self.device)
        self.discriminator = PatchGAN(sigmoid=self.args.no_mse).to(self.device)

        # set optimizers
        self.optimizers = self._set_optimizers()

        # set loss functions
        self.losses = self._set_losses()

        # set image pooler
        self.image_pool = ImagePooling(50)

        # load pretrained model
        if self.args.pretrainedG != '':
            if self.args.verbose:
                print('load pretrained generator...')
            load_checkpoints(self.args.pretrainedG, self.generator,
                             self.optimizers['G'])
        if self.args.pretrainedD != '':
            if self.args.verbose:
                print('load pretrained discriminator...')
            load_checkpoints(self.args.pretrainedD, self.discriminator,
                             self.optimizers['D'])

        if self.device.type == 'cuda':
            # enable parallel computation
            self.generator = nn.DataParallel(self.generator)
            self.discriminator = nn.DataParallel(self.discriminator)

        # loss values for tracking
        self.loss_G_gan = AverageTracker('loss_G_gan')
        self.loss_G_l1 = AverageTracker('loss_G_l1')
        self.loss_D_real = AverageTracker('loss_D_real')
        self.loss_D_fake = AverageTracker('loss_D_fake')

        # image value
        self.imageA = None
        self.imageB = None
        self.fakeB = None