Esempio n. 1
0
def main(_):
    """Fit a separate implicit neural representation to every image in a dataset.

    For each file matched by ``FLAGS.data_root/FLAGS.dataset/*`` a freshly
    built model of type ``FLAGS.model_type`` is trained on that image alone
    (``trainingAC.train_phased`` for ``'parallel'``, ``trainingAC.train``
    otherwise). The MSE, SSIM and PSNR returned by
    ``utils.check_metrics_full`` are collected per image name and written to
    ``<experiment_folder>/result.json``; the flag values are saved next to it
    as ``FLAGS.yml``.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.loss`` or ``FLAGS.model_type`` is not one of
            the supported values (previously this surfaced later as an
            ``UnboundLocalError``).
    """
    # glob returns files in arbitrary filesystem order; sort for a
    # reproducible processing order and log output.
    imglob = sorted(
        glob.glob(os.path.join(FLAGS.data_root, FLAGS.dataset, '*')))
    mses = {}
    psnrs = {}
    ssims = {}
    experiment_folder = get_experiment_folder()
    # Save FLAGS to yml; `with` guarantees the file is flushed and closed.
    with open(os.path.join(experiment_folder, 'FLAGS.yml'), 'w') as flags_file:
        yaml.dump(FLAGS.flag_values_dict(), flags_file)

    # The loss does not depend on the image, so select it once up front.
    if FLAGS.loss == 'mse':
        loss_fn = partial(loss_functions.image_mse, None)
    elif FLAGS.loss == 'log_mse':
        loss_fn = image_log_mse
    else:
        raise ValueError('Unsupported loss: {!r}'.format(FLAGS.loss))

    for i, im in enumerate(imglob):
        print('Image: ' + str(i))
        image_name = os.path.basename(im).split('.')[0]
        img_dataset = dataio.ImageFile(im)
        # Only the size is needed here; close the handle instead of leaking it.
        with PIL.Image.open(im) as img:
            width, height = img.size
        # Stored as (height, width), both divided by the downscaling factor.
        image_resolution = (height // FLAGS.downscaling_factor,
                            width // FLAGS.downscaling_factor)

        coord_dataset = dataio.Implicit2DWrapper(img_dataset,
                                                 sidelength=image_resolution)

        dataloader = DataLoader(coord_dataset,
                                shuffle=True,
                                batch_size=FLAGS.batch_size,
                                pin_memory=True,
                                num_workers=0)

        # Define the model.
        if FLAGS.model_type == 'mlp':
            model = modules.SingleBVPNet_INR(
                type=FLAGS.activation,
                mode=FLAGS.encoding,
                sidelength=image_resolution,
                out_features=img_dataset.img_channels,
                hidden_features=FLAGS.hidden_dims,
                num_hidden_layers=FLAGS.hidden_layers,
                encoding_scale=FLAGS.encoding_scale,
                batch_norm=FLAGS.bn,
                ff_dims=FLAGS.ff_dims)
        elif FLAGS.model_type == 'multi_tapered':
            model = modules.MultiScale_INR(
                type=FLAGS.activation,
                mode=FLAGS.encoding,
                sidelength=image_resolution,
                out_features=img_dataset.img_channels,
                hidden_features=FLAGS.hidden_dims,
                num_hidden_layers=FLAGS.hidden_layers,
                encoding_scale=FLAGS.encoding_scale,
                tapered=True,
                downsample=False,
                ff_dims=FLAGS.ff_dims)
        elif FLAGS.model_type == 'multi':
            model = modules.MultiScale_INR(
                type=FLAGS.activation,
                mode=FLAGS.encoding,
                sidelength=image_resolution,
                out_features=img_dataset.img_channels,
                hidden_features=FLAGS.hidden_dims,
                num_hidden_layers=FLAGS.hidden_layers,
                encoding_scale=FLAGS.encoding_scale,
                tapered=False,
                downsample=False,
                ff_dims=FLAGS.ff_dims)
        elif FLAGS.model_type == 'parallel':
            # Three parallel branches of width hidden_dims/4, /2 and /1.
            model = modules.Parallel_INR(type=FLAGS.activation,
                                         mode=FLAGS.encoding,
                                         sidelength=image_resolution,
                                         out_features=img_dataset.img_channels,
                                         hidden_features=[
                                             FLAGS.hidden_dims // 4,
                                             FLAGS.hidden_dims // 2,
                                             FLAGS.hidden_dims
                                         ],
                                         num_hidden_layers=FLAGS.hidden_layers,
                                         encoding_scale=FLAGS.encoding_scale)
        elif FLAGS.model_type == 'mixture':
            model = modules.INR_Mixture(type=FLAGS.activation,
                                        mode=FLAGS.encoding,
                                        sidelength=image_resolution,
                                        out_features=img_dataset.img_channels,
                                        hidden_features=FLAGS.hidden_dims,
                                        num_hidden_layers=FLAGS.hidden_layers,
                                        encoding_scale=FLAGS.encoding_scale,
                                        batch_norm=FLAGS.bn,
                                        ff_dims=FLAGS.ff_dims,
                                        num_components=FLAGS.num_components)
        else:
            raise ValueError(
                'Unsupported model_type: {!r}'.format(FLAGS.model_type))
        model.cuda()
        root_path = os.path.join(experiment_folder, image_name)

        # Summaries are rendered at this image's resolution.
        summary_fn = partial(siren_utils.write_image_summary, image_resolution)

        if FLAGS.model_type == 'parallel':
            trainingAC.train_phased(
                model=model,
                train_dataloader=dataloader,
                epochs=FLAGS.epochs,
                lr=FLAGS.lr,
                steps_til_summary=FLAGS.steps_til_summary,
                epochs_til_checkpoint=FLAGS.epochs_til_ckpt,
                model_dir=root_path,
                loss_fn=loss_fn,
                summary_fn=summary_fn,
                l1_reg=FLAGS.l1_reg,
                spec_reg=FLAGS.spec_reg,
                phased=FLAGS.phased,
                intermediate_losses=FLAGS.intermediate_losses)
        else:
            trainingAC.train(model=model,
                             train_dataloader=dataloader,
                             epochs=FLAGS.epochs,
                             lr=FLAGS.lr,
                             steps_til_summary=FLAGS.steps_til_summary,
                             epochs_til_checkpoint=FLAGS.epochs_til_ckpt,
                             model_dir=root_path,
                             loss_fn=loss_fn,
                             summary_fn=summary_fn,
                             l1_reg=FLAGS.l1_reg,
                             spec_reg=FLAGS.spec_reg,
                             loss_schedules=None)

        mse, ssim, psnr = utils.check_metrics_full(dataloader, model,
                                                   image_resolution)
        mses[image_name] = mse
        psnrs[image_name] = psnr
        ssims[image_name] = ssim
    metrics = {'mse': mses, 'psnr': psnrs, 'ssim': ssims}

    with open(os.path.join(experiment_folder, 'result.json'), 'w') as fp:
        json.dump(metrics, fp)
Esempio n. 2
0
 # Fill in defaults for flags missing from TRAINING_FLAGS (presumably
 # configs saved before these options existed). Each default is applied
 # independently of the others.
 if 'bn' not in TRAINING_FLAGS:
     TRAINING_FLAGS['bn'] = False
 if 'intermediate_losses' not in TRAINING_FLAGS:
     TRAINING_FLAGS['intermediate_losses'] = False
 # Not nested under the 'intermediate_losses' check: otherwise a config
 # that already has 'intermediate_losses' but lacks 'phased' keeps no
 # 'phased' key at all.
 if 'phased' not in TRAINING_FLAGS:
     TRAINING_FLAGS['phased'] = False
 if 'ff_dims' not in TRAINING_FLAGS:
     TRAINING_FLAGS['ff_dims'] = None
 if 'num_components' not in TRAINING_FLAGS:
     TRAINING_FLAGS['num_components'] = 1
 # Rebuild the model from the stored flags; the encoding scale comes from
 # `s`, which is defined outside this fragment.
 if TRAINING_FLAGS['model_type'] == 'mlp':
     model = modules.SingleBVPNet_INR(
         type=TRAINING_FLAGS['activation'],
         mode=TRAINING_FLAGS['encoding'],
         sidelength=image_resolution,
         out_features=img_dataset.img_channels,
         hidden_features=TRAINING_FLAGS['hidden_dims'],
         num_hidden_layers=TRAINING_FLAGS['hidden_layers'],
         encoding_scale=s,
         batch_norm=TRAINING_FLAGS['bn'],
         ff_dims=TRAINING_FLAGS['ff_dims'])
 elif TRAINING_FLAGS['model_type'] == 'multi_tapered':
     model = modules.MultiScale_INR(
         type=TRAINING_FLAGS['activation'],
         mode=TRAINING_FLAGS['encoding'],
         sidelength=image_resolution,
         out_features=img_dataset.img_channels,
         hidden_features=TRAINING_FLAGS['hidden_dims'],
         num_hidden_layers=TRAINING_FLAGS['hidden_layers'],
         encoding_scale=s,
         tapered=True,
         ff_dims=TRAINING_FLAGS['ff_dims'])
Esempio n. 3
0
def main(_):
    """Meta-learn an INR initialisation with MAML, then fine-tune it per image.

    Stage 1: every image in ``FLAGS.data_root/FLAGS.maml_dataset`` is wrapped
    in one ``Implicit2DListWrapper`` dataset. If
    ``<maml_folder>/checkpoints/model_maml.pth`` loads into the freshly built
    model it is reused; otherwise the flags are saved to
    ``<maml_folder>/FLAGS.yml`` and the model is meta-trained with
    ``trainingMAML.train``. The resulting weights are saved to
    ``<experiment_folder>/model_maml.pth``.

    Stage 2: for every image in ``FLAGS.data_root/FLAGS.dataset`` the model is
    reset to the meta-learned weights and fine-tuned with
    ``trainingAC.train``, with an extra L1 term against those weights
    (``losses.model_l1_diff``). Each image's MSE/PSNR/SSIM is written to
    ``<experiment_folder>/<image_name>/result.json``.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If no MAML images are found, or ``FLAGS.loss`` /
            ``FLAGS.model_type`` is not a supported value.
    """
    maml_pattern = os.path.join(FLAGS.data_root, FLAGS.maml_dataset, '*')
    # glob order is arbitrary; sort for reproducible runs.
    imglob_maml = sorted(glob.glob(maml_pattern))
    imglob = sorted(
        glob.glob(os.path.join(FLAGS.data_root, FLAGS.dataset, '*')))
    if not imglob_maml:
        # Without this, an empty glob later fails on an unbound
        # `image_resolution`.
        raise ValueError('No MAML training images match ' + maml_pattern)

    experiment_folder = get_experiment_folder()
    maml_folder = get_maml_folder()
    # Save FLAGS to yml; `with` guarantees the file is flushed and closed.
    with open(os.path.join(experiment_folder, 'FLAGS.yml'), 'w') as flags_file:
        yaml.dump(FLAGS.flag_values_dict(), flags_file)

    # ---- Stage 1: MAML meta-training data ---------------------------------
    # The meta-model is built with the resolution of the *last* MAML image,
    # so these images are presumably expected to share one size.
    # NOTE(review): confirm.
    maml_datasets = []
    for im in imglob_maml:
        maml_datasets.append(dataio.ImageFile(im))
        with PIL.Image.open(im) as img:
            width, height = img.size
        image_resolution = (height // FLAGS.downscaling_factor,
                            width // FLAGS.downscaling_factor)
    coord_dataset = dataio.Implicit2DListWrapper(maml_datasets,
                                                 sidelength=image_resolution)

    dataloader = DataLoader(coord_dataset,
                            shuffle=True,
                            batch_size=FLAGS.batch_size,
                            pin_memory=True,
                            num_workers=0)
    out_features = maml_datasets[0].img_channels

    # Define the model.
    if FLAGS.model_type == 'mlp':
        model = modules.SingleBVPNet_INR(
            type=FLAGS.activation,
            mode=FLAGS.encoding,
            sidelength=image_resolution,
            out_features=out_features,
            hidden_features=FLAGS.hidden_dims,
            num_hidden_layers=FLAGS.hidden_layers,
            encoding_scale=FLAGS.encoding_scale,
            batch_norm=FLAGS.bn,
            ff_dims=FLAGS.ff_dims)
    elif FLAGS.model_type == 'multi_tapered':
        model = modules.MultiScale_INR(
            type=FLAGS.activation,
            mode=FLAGS.encoding,
            sidelength=image_resolution,
            out_features=out_features,
            hidden_features=FLAGS.hidden_dims,
            num_hidden_layers=FLAGS.hidden_layers,
            encoding_scale=FLAGS.encoding_scale,
            tapered=True,
            downsample=False,
            ff_dims=FLAGS.ff_dims)
    elif FLAGS.model_type == 'multi':
        model = modules.MultiScale_INR(
            type=FLAGS.activation,
            mode=FLAGS.encoding,
            sidelength=image_resolution,
            out_features=out_features,
            hidden_features=FLAGS.hidden_dims,
            num_hidden_layers=FLAGS.hidden_layers,
            encoding_scale=FLAGS.encoding_scale,
            tapered=False,
            downsample=False,
            ff_dims=FLAGS.ff_dims)
    elif FLAGS.model_type == 'parallel':
        # Three parallel branches of width hidden_dims/4, /2 and /1.
        model = modules.Parallel_INR(type=FLAGS.activation,
                                     mode=FLAGS.encoding,
                                     sidelength=image_resolution,
                                     out_features=out_features,
                                     hidden_features=[
                                         FLAGS.hidden_dims // 4,
                                         FLAGS.hidden_dims // 2,
                                         FLAGS.hidden_dims
                                     ],
                                     num_hidden_layers=FLAGS.hidden_layers,
                                     encoding_scale=FLAGS.encoding_scale)
    elif FLAGS.model_type == 'mixture':
        model = modules.INR_Mixture(type=FLAGS.activation,
                                    mode=FLAGS.encoding,
                                    sidelength=image_resolution,
                                    out_features=out_features,
                                    hidden_features=FLAGS.hidden_dims,
                                    num_hidden_layers=FLAGS.hidden_layers,
                                    encoding_scale=FLAGS.encoding_scale,
                                    batch_norm=FLAGS.bn,
                                    ff_dims=FLAGS.ff_dims,
                                    num_components=FLAGS.num_components)
    else:
        raise ValueError(
            'Unsupported model_type: {!r}'.format(FLAGS.model_type))
    model.cuda()

    # Define the loss
    if FLAGS.loss == 'mse':
        loss_fn = partial(loss_functions.image_mse, None)
    elif FLAGS.loss == 'log_mse':
        loss_fn = image_log_mse
    else:
        raise ValueError('Unsupported loss: {!r}'.format(FLAGS.loss))
    maml_summary_fn = partial(siren_utils.write_image_summary,
                              image_resolution)

    try:
        state_dict = torch.load(os.path.join(maml_folder,
                                             'checkpoints/model_maml.pth'),
                                map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
    except Exception as err:
        # Missing, unreadable or shape-incompatible checkpoint: fall back to
        # meta-training. `Exception` rather than a bare `except:` so that
        # Ctrl-C / SystemExit still abort the run.
        print("No matching model found, training from scratch.")
        print('  reason: {}: {}'.format(type(err).__name__, err))
        with open(os.path.join(maml_folder, 'FLAGS.yml'), 'w') as flags_file:
            yaml.dump(FLAGS.flag_values_dict(), flags_file)
        trainingMAML.train(model=model,
                           train_dataloader=dataloader,
                           maml_iterations=FLAGS.maml_iterations,
                           inner_lr=FLAGS.inner_lr,
                           outer_lr=FLAGS.outer_lr,
                           steps_til_summary=FLAGS.steps_til_summary,
                           epochs_til_checkpoint=FLAGS.epochs_til_ckpt,
                           model_dir=maml_folder,
                           loss_fn=loss_fn,
                           summary_fn=maml_summary_fn,
                           maml_batch_size=FLAGS.maml_batch_size,
                           maml_adaptation_steps=FLAGS.maml_adaptation_steps)

    # Frozen copy of the meta-learned weights, used as the L1 reference.
    ref_model = copy.deepcopy(model)
    l1_loss_fn = partial(losses.model_l1_diff, ref_model)
    # Separate snapshot every per-image fine-tune is reset to.
    meta_state = copy.deepcopy(model.state_dict())

    torch.save(model.state_dict(),
               os.path.join(experiment_folder, 'model_maml.pth'))

    # ---- Stage 2: per-image fine-tuning -----------------------------------
    for i, im in enumerate(imglob):
        print('Image: ' + str(i))
        image_name = os.path.basename(im).split('.')[0]
        img_dataset = dataio.ImageFile(im)
        with PIL.Image.open(im) as img:
            width, height = img.size
        image_resolution = (height // FLAGS.downscaling_factor,
                            width // FLAGS.downscaling_factor)

        coord_dataset = dataio.Implicit2DWrapper(img_dataset,
                                                 sidelength=image_resolution)

        root_path = os.path.join(experiment_folder, image_name)
        dataloader = DataLoader(coord_dataset,
                                shuffle=True,
                                batch_size=FLAGS.batch_size,
                                pin_memory=True,
                                num_workers=0)
        # Start each image from the meta-learned initialisation. Without this
        # reset, image k would be fine-tuned from image k-1's weights.
        model.load_state_dict(meta_state)
        # Summaries must use this image's resolution, not the MAML one.
        summary_fn = partial(siren_utils.write_image_summary,
                             image_resolution)
        trainingAC.train(model=model,
                         train_dataloader=dataloader,
                         epochs=FLAGS.epochs,
                         lr=FLAGS.lr,
                         steps_til_summary=FLAGS.steps_til_summary,
                         epochs_til_checkpoint=FLAGS.epochs_til_ckpt,
                         model_dir=root_path,
                         loss_fn=loss_fn,
                         l1_loss_fn=l1_loss_fn,
                         summary_fn=summary_fn,
                         l1_reg=FLAGS.l1_reg,
                         spec_reg=FLAGS.spec_reg)

        mse, ssim, psnr = utils.check_metrics_full(dataloader, model,
                                                   image_resolution)
        metrics = {'mse': mse, 'psnr': psnr, 'ssim': ssim}

        with open(os.path.join(root_path, 'result.json'), 'w') as fp:
            json.dump(metrics, fp)