Example #1
# Assumed imports for this excerpt; utils, config and checkpoints are
# project-local modules.
import logging
import os

import numpy as np
import torch
import yaml

import checkpoints
import config
import utils


def main(cfg, num_workers):
    # Shortened
    out_dir = cfg['training']['out_dir']
    batch_size = cfg['training']['batch_size']
    utils.save_config(os.path.join(out_dir, 'config.yml'), cfg)

    model_selection_metric = cfg['training']['model_selection_metric']
    model_selection_sign = 1 if cfg['training'][
        'model_selection_mode'] == 'maximize' else -1

    # Output directory
    utils.cond_mkdir(out_dir)

    # Dataset
    test_dataset = config.get_dataset('test', cfg)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              shuffle=False)

    # Model
    model = config.get_model(cfg)
    trainer = config.get_trainer(model, None, cfg)

    # Print model
    print(model)
    logger = logging.getLogger(__name__)
    logger.info(
        f'Total number of parameters: {sum(p.numel() for p in model.parameters())}'
    )

    ckp = checkpoints.CheckpointIO(out_dir, model, None, cfg)
    try:
        load_dict = ckp.load('model_best.pt')
        logger.info('Model loaded')
    except FileExistsError:  # this project's CheckpointIO signals a missing file this way
        logger.info('Model NOT loaded')
        load_dict = dict()

    metric_val_best = load_dict.get('loss_val_best',
                                    -model_selection_sign * np.inf)

    logger.info(
        f'Current best validation metric ({model_selection_metric}): {metric_val_best:.6f}'
    )

    eval_dict = trainer.evaluate(test_loader)
    metric_val = eval_dict[model_selection_metric]
    logger.info(
        f'Validation metric ({model_selection_metric}): {metric_val:.8f}')

    eval_dict_path = os.path.join(out_dir, 'eval_dict.yml')
    with open(eval_dict_path, 'w') as f:
        yaml.dump(eval_dict, f)  # dump the evaluation results, not the config module

    print(f'Results saved in {eval_dict_path}')
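The model_selection_sign trick above folds 'maximize' and 'minimize' into one comparison: with sign = +1 the best value starts at -inf and grows, with sign = -1 it starts at +inf and shrinks. A tiny self-contained illustration with made-up values:

sign = 1  # +1 for 'maximize', -1 for 'minimize'
best = -sign * float('inf')
for val in [0.2, 0.5, 0.4]:
    if sign * (val - best) > 0:  # "val is strictly better than best"
        best = val
print(best)  # 0.5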
Example #2
def main():
    args = get_args()
    random.seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.CVDs  # select GPUs before any CUDA work
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    current_time, f, save_path, writer = makedir(args)
    args.train_set_old = '/data0/zili/code/checkpoints/Jun05_18-25-14'
    if args.server == 15:
        args.train_set_old = '/data0/share/zili/checkpoints/Jun09_17-21-53'

    model, preserved = get_model(args)
    model = model.cuda()
    if args.increment_phase > 0:
        # freeze the first child module's parameters (the conv backbone);
        # the break keeps all later children trainable
        for p in model.children():
            for c in p.parameters():
                c.requires_grad = False
            break

    optimizer, scheduler, criterion, embedding_loss_function = get_osc(args, model)
    criterion = criterion.cuda()
    embedding_loss_function = embedding_loss_function.cuda()

    sampler_train_loader, train_loader, test_loader, sampler_train_loader_old, train_loader_old, classes = get_dataloader(
        args)

    trainer = Trainer(args, optimizer, scheduler, sampler_train_loader,
                      train_loader, test_loader, model, preserved,
                      sampler_train_loader_old, train_loader_old, criterion,
                      embedding_loss_function, writer, f, save_path, classes)
    trainer.run()

    f.close()
    writer.close()
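The break after the first child above is what restricts the freeze to the leading (convolutional) submodule. The same pattern on a toy model, as a sketch assuming only torch (the layer layout is illustrative):

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8, 2))
backbone = next(net.children())  # first child only: the Conv2d
for p in backbone.parameters():
    p.requires_grad = False

print([n for n, p in net.named_parameters() if p.requires_grad])
# ['3.weight', '3.bias'] -- only the Linear head stays trainable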
Example #3
def complete(args):
    """
    Performs in-painting over images using a pre-trained
    GAN model : http://arxiv.org/abs/1607.07539

    """

    image_paths = get_image_paths(args.images)
    nImgs = len(image_paths)

    print('Images found : {}'.format(nImgs))

    if args.dataset == 'celeba':
        crop = True
        n_channels = 3
    else:
        n_channels = 1
        crop = False

    image_shape = [int(args.image_size), int(args.image_size), n_channels]

    batch_idxs = int(np.ceil(nImgs / args.batch_size))

    maskType = args.maskType

    folder_name = os.path.join('completions', args.dataset.lower(),
                               args.model.lower())
    dumpDir = os.path.join(folder_name, maskType)

    if os.path.exists(dumpDir):
        shutil.rmtree(dumpDir)
    os.makedirs(dumpDir)

    if maskType == 'random':
        fraction_masked = 0.2
        mask = np.ones(image_shape)
        mask[np.random.random(image_shape[:2]) < fraction_masked] = 0.0
    elif maskType == 'center':  # Center mask removes 25% of the image
        patch_size = args.image_size // 2
        crop_pos = (args.image_size - patch_size) / 2
        mask = np.ones(image_shape)
        l = int(crop_pos)
        u = int(crop_pos + patch_size)
        mask[l:u, l:u, :] = 0.0
    elif maskType == 'left':
        mask = np.ones(image_shape)
        c = args.image_size // 2
        mask[:, :c, :] = 0.0
    elif maskType == 'full':
        mask = np.ones(image_shape)
    elif maskType == 'grid':
        mask = np.zeros(image_shape)
        mask[::4, ::4, :] = 1.0
    elif maskType == 'bottom':
        mask = np.ones(image_shape)
        bottom_half = args.image_size // 2
        mask[bottom_half:, :, :] = 0.0
    else:
        raise ValueError('Invalid mask type provided')

    tf_config = tf.ConfigProto()

    with tf.Session(config=tf_config) as sess:
        try:
            tf.global_variables_initializer().run()
        except AttributeError:  # very old TF only has the deprecated initializer
            tf.initialize_all_variables().run()
        # Load the checkpoint file into the model
        # Get the model
        # If Training is set to false, the discriminator ops graph is not built.
        # The discriminator graph is used to compute the in-painting loss. Hacked it now, please FIX THIS - TODO

        if args.model.lower() in ('dragan', 'dcgan-cons'):
            # Pick the non-BN version of DRAGAN and DCGAN-CONS
            model = config.get_model(args.model.upper(),
                                     args.model.lower(),
                                     training=True,
                                     batch_norm=False,
                                     image_shape=image_shape)
        else:
            model = config.get_model(args.model.upper(),
                                     args.model.lower(),
                                     training=True,
                                     image_shape=image_shape)

        restorer = tf.train.Saver()
        checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints',
                                      args.dataset.lower(), args.model.lower())
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            restorer.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise ValueError('Invalid checkpoint directory')

        critic = model.name in ('wgan', 'wgan-gp')

        for idx in range(0, batch_idxs):
            l = idx * args.batch_size
            u = min((idx + 1) * args.batch_size, nImgs)
            batchSz = u - l
            batch_files = image_paths[l:u]
            batch = [
                read_image(batch_file, [args.image_size, args.image_size],
                           n_channels=n_channels,
                           crop=crop) for batch_file in batch_files
            ]
            batch_images = np.array(batch).astype(np.float32)
            masked_images = np.multiply(batch_images, mask)
            if batchSz < args.batch_size:
                padSz = ((0, int(args.batch_size - batchSz)), (0, 0), (0, 0),
                         (0, 0))
                batch_images = np.pad(batch_images, padSz, 'constant')
                batch_images = batch_images.astype(np.float32)
                masked_images = np.multiply(batch_images, mask)

            zhats = np.random.uniform(-1,
                                      1,
                                      size=(args.batch_size, model.z_dim))

            # tracks the discriminator/critic score for every image in the batch
            disc_score_tracker = []
            perceptual_loss_grad = []
            contextual_loss_grad = []
            # Variables for ADAM
            m = 0
            v = 0

            for file_idx in range(len(batch_images)):
                folder_idx = l + file_idx

                outDir = os.path.join(dumpDir, '{}'.format(folder_idx))
                os.makedirs(
                    outDir
                )  # Directory that stores real and masked images, different for each real image

                if n_channels == 3:
                    genDir = os.path.join(
                        outDir, 'gen_images'
                    )  # Directory that stores iterations of in-paintings
                    os.makedirs(genDir)

                genDir_overlay = os.path.join(
                    outDir, 'gen_images_overlay'
                )  # Directory that stores iterations of in-paintings
                os.makedirs(genDir_overlay)

                gzDir = os.path.join(outDir, 'gz')
                os.makedirs(gzDir)

                save_image(image=batch_images[file_idx, :, :, :],
                           path=os.path.join(outDir, 'original.jpg'),
                           n_channels=n_channels)
                save_image(image=masked_images[file_idx, :, :, :],
                           path=os.path.join(outDir, 'masked.jpg'),
                           n_channels=n_channels)

            for i in range(args.nIter):
                fd = {
                    model.z: zhats,
                    model.mask: mask,
                    model.X: batch_images,
                }
                run = [
                    model.complete_loss, model.perceptual_loss,
                    model.contextual_loss, model.grad_complete_loss, model.G,
                    model.grad_norm_perceptual_loss,
                    model.grad_norm_contextual_loss
                ]
                complete_loss, perceptual_loss, contextual_loss, g, G_imgs, grad_norm_perceptual_loss, grad_norm_contextual_loss = sess.run(
                    run, feed_dict=fd)

                # Capture the gradient norms of both loss components
                perceptual_loss_grad.append(grad_norm_perceptual_loss[0])
                contextual_loss_grad.append(grad_norm_contextual_loss[0])

                if model.name != 'wgan' and model.name != 'wgan-gp':
                    disc_scores = sess.run(model.D_fake_prob,
                                           feed_dict={model.z: zhats})
                else:
                    disc_scores = sess.run(model.C_fake,
                                           feed_dict={model.z: zhats})

                disc_score_tracker.append(disc_scores.flatten())

                if i % 100 == 0:
                    # Mean score given to this batch of images by the discriminator/critic
                    mean_disc_score = np.mean(disc_scores)

                    print(
                        'Timestamp: {:%Y-%m-%d %H:%M:%S} Batch : {}/{}. Iteration : {}. Mean complete loss : {} Mean Perceptual loss : {} Mean Contextual Loss: {} Discriminator/Critic Score: {}'
                        .format(datetime.now(), idx, batch_idxs, i,
                                np.mean(complete_loss[0:batchSz]),
                                perceptual_loss,
                                np.mean(contextual_loss[0:batchSz]),
                                mean_disc_score))

                    inv_masked_hat_images = np.multiply(G_imgs, 1.0 - mask)
                    completed = []

                    # Direct overlay
                    overlay = masked_images + inv_masked_hat_images

                    # Poisson blending
                    if n_channels == 3:  # OpenCV Poisson blending supports only 3-channel images. FIXME
                        for indx, img in enumerate(G_imgs):
                            completed.append(
                                blend_images(image=overlay[indx, :, :, :],
                                             gen_image=img,
                                             mask=np.multiply(255, 1.0 - mask)))
                        completed = np.asarray(completed)

                    # Save all in-painted images of this iteration in their respective image folders
                    for image_idx in range(args.batch_size):
                        folder_idx = l + image_idx

                        save_path_overlay = os.path.join(
                            dumpDir, '{}'.format(folder_idx),
                            'gen_images_overlay', 'gen_{}.jpg'.format(i))
                        save_path_gz = os.path.join(dumpDir,
                                                    '{}'.format(folder_idx),
                                                    'gz',
                                                    'gz_{}.jpg'.format(i))
                        overlay[image_idx, :, :, :] = rescale_image(
                            overlay[image_idx, :, :, :])

                        if n_channels == 3:
                            save_path = os.path.join(dumpDir,
                                                     '{}'.format(folder_idx),
                                                     'gen_images',
                                                     'gen_{}.jpg'.format(i))
                            save_image(image=completed[image_idx, :, :, :],
                                       path=save_path,
                                       n_channels=n_channels)

                        save_image(image=overlay[image_idx, :, :, :],
                                   path=save_path_overlay,
                                   n_channels=n_channels)
                        save_image(image=rescale_image(
                            G_imgs[image_idx, :, :, :]),
                                   path=save_path_gz,
                                   n_channels=n_channels)

                # Adam implementation
                m_prev = np.copy(m)
                v_prev = np.copy(v)
                m = args.beta1 * m_prev + (1 - args.beta1) * g[0]
                v = args.beta2 * v_prev + (1 - args.beta2) * np.multiply(
                    g[0], g[0])
                m_hat = m / (1 - args.beta1**(i + 1))
                v_hat = v / (1 - args.beta2**(i + 1))
                zhats += -np.true_divide(args.lr * m_hat,
                                         (np.sqrt(v_hat) + args.eps))

                sys.stdout.flush()

                if args.clipping == 'standard':
                    # Standard clipping
                    zhats = np.clip(zhats, -1, 1)
                elif args.clipping == 'stochastic':
                    # Stochastic clipping: resample out-of-range entries uniformly
                    out_of_range = (zhats < -1) | (zhats > 1)
                    zhats[out_of_range] = np.random.uniform(
                        -1, 1, size=np.count_nonzero(out_of_range))
                else:
                    raise ValueError('Invalid clipping mode')

            # Save the tracked score/gradient arrays for the batch once done
            disc_score_tracker = np.asarray(disc_score_tracker)
            perceptual_loss_grad = np.asarray(perceptual_loss_grad)
            contextual_loss_grad = np.asarray(contextual_loss_grad)

            with open(
                    os.path.join(dumpDir,
                                 'disc_scores_batch_{}.pkl'.format(idx)),
                    'wb') as f:
                pickle.dump(disc_score_tracker, f)

            with open(
                    os.path.join(dumpDir,
                                 'p_loss_grad_batch_{}.pkl'.format(idx)),
                    'wb') as f:
                pickle.dump(perceptual_loss_grad, f)

            with open(
                    os.path.join(dumpDir,
                                 'c_loss_grad_batch_{}.pkl'.format(idx)),
                    'wb') as f:
                pickle.dump(contextual_loss_grad, f)
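The "# Adam implementation" block above hand-rolls the Adam update for the latent vector zhats instead of using an optimizer object. A self-contained sketch of that same update on a toy objective f(z) = ||z||^2 (hyperparameters are illustrative, not the script's):

import numpy as np

lr, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-8
z = np.array([0.9, -0.7])
m = v = 0.0
for t in range(1, 201):
    g = 2 * z                              # gradient of ||z||^2
    m = beta1 * m + (1 - beta1) * g        # first-moment estimate
    v = beta2 * v + (1 - beta2) * g * g    # second-moment estimate
    m_hat = m / (1 - beta1 ** t)           # bias correction
    v_hat = v / (1 - beta2 ** t)
    z -= lr * m_hat / (np.sqrt(v_hat) + eps)
print(z)  # approaches [0, 0]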
Example #4
import torch
import argparse
from config import get_model, get_faces
from data.obj_utils import write_obj

model = get_model()
f = get_faces()


def export(x):
    write_obj('learning_out.obj', x, f)

def decode(z):
    with torch.set_grad_enabled(False):
        model.eval()
        z_tensor = torch.tensor([z]).view(1, len(z))
        x = model.decode(z_tensor)
        x = x[0].transpose(1, 0)
        export(x)
        return x


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('latents', metavar='N', type=float, nargs='+')
    args = parser.parse_args()
    decode(args.latents)
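The script takes the latent vector as positional command-line floats; a hypothetical invocation (the file name is assumed):

$ python decode_latents.py 0.5 -0.25 0.75 0.1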

Example #5
        # Excerpt: tail of a training loop; the opening of this example was lost
        coord.join(threads)
        summary_writer.close()
        pbar.close()


if __name__ == "__main__":
    parser = build_parser()
    FLAGS = parser.parse_args()
    FLAGS.model = FLAGS.model.upper()
    FLAGS.dataset = FLAGS.dataset.lower()
    if FLAGS.name is None:
        FLAGS.name = FLAGS.model.lower()
    config.pprint_args(FLAGS)

    # get information for dataset
    dataset_pattern, n_examples = config.get_dataset(FLAGS.dataset)

    # input pipeline
    X = input_pipeline(dataset_pattern,
                       batch_size=FLAGS.batch_size,
                       num_threads=FLAGS.num_threads,
                       num_epochs=FLAGS.num_epochs)
    model = config.get_model(FLAGS.model, FLAGS.name, training=True)
    train(model=model,
          input_op=X,
          num_epochs=FLAGS.num_epochs,
          batch_size=FLAGS.batch_size,
          n_examples=n_examples,
          renew=FLAGS.renew)
Example #6

'''
You can create a gif movie through imagemagick on the commandline:
$ convert -delay 20 eval/* movie.gif
'''
# def to_gif(dir_name='eval'):
#     images = []
#     for path in glob.glob(os.path.join(dir_name, '*.png')):
#         im = scipy.misc.imread(path)
#         images.append(im)

#     # make_gif(images, dir_name + '/movie.gif', duration=10, true_image=True)
#     imageio.mimsave('movie.gif', images, duration=0.2)

if __name__ == "__main__":
    parser = build_parser()
    FLAGS = parser.parse_args()
    FLAGS.model = FLAGS.model.upper()
    FLAGS.dataset = FLAGS.dataset.lower()
    if FLAGS.name is None:
        FLAGS.name = FLAGS.model.lower()
    config.pprint_args(FLAGS)

    N = int(FLAGS.sample)
    rep = int(FLAGS.rep)

    # training=False => build generator only
    model = config.get_model(FLAGS.model, FLAGS.name.upper(), training=False)
    eval_dump(model, name=FLAGS.name.upper(), dataset=FLAGS.dataset, rep=rep)
Example #7
                       num_threads=FLAGS.num_threads,
                       num_epochs=FLAGS.num_epochs,
                       image_size=FLAGS.image_size,
                       dataset=FLAGS.dataset)

    # Arbitrarily sized crops will be resized to 64x64x3. Model will be constructed accordingly

    image_shape = [FLAGS.image_size, FLAGS.image_size, n_channels]
    batch_norm = True

    if FLAGS.name in ('dragan', 'dcgan-cons'):
        batch_norm = False

    if FLAGS.simultaneous and FLAGS.model == 'DCGAN':
        FLAGS.name = FLAGS.model.lower() + '_sim'

    model = config.get_model(FLAGS.model,
                             FLAGS.name,
                             training=True,
                             image_shape=image_shape,
                             batch_norm=batch_norm)
    train(model=model,
          dataset=FLAGS.dataset,
          input_op=X,
          num_epochs=FLAGS.num_epochs,
          batch_size=FLAGS.batch_size,
          n_examples=n_examples,
          ckpt_step=FLAGS.ckpt_step,
          renew=FLAGS.renew,
          simultaneous=FLAGS.simultaneous)
Example #8
if __name__ == "__main__":
    parser = build_parser()
    FLAGS = parser.parse_args()
    FLAGS.model = FLAGS.model.upper()
    FLAGS.dataset = FLAGS.dataset.lower()
    if FLAGS.name is None:
        FLAGS.name = FLAGS.model.lower()
    config.pprint_args(FLAGS)

    if FLAGS.model.lower() in ('dragan', 'dcgan-cons'):
        # Pick the non-BN version of DRAGAN and DCGAN-CONS
        model = config.get_model(FLAGS.model.upper(),
                                 FLAGS.model.lower(),
                                 training=True,
                                 batch_norm=False)
    else:
        model = config.get_model(FLAGS.model.upper(),
                                 FLAGS.model.lower(),
                                 training=True)

    eval(model,
         dataset=FLAGS.dataset,
         name=FLAGS.name,
         batch_size=FLAGS.batch_size,
         load_all_ckpt=True)
Example #9
$ convert -delay 20 eval/* movie.gif
'''


def to_gif(dir_name='eval'):
    images = []
    im_list = []
    # for path in glob.glob('*.jpg'):
    #     im_list.append(path)
    for i in range(9):
        im = scipy.misc.imread(dir_name + '/' + str(i + 1) + '.jpg')
        im = scipy.misc.imresize(im, [256, 256])
        images.append(im)

    # make_gif(images, dir_name + '/movie.gif', duration=10, true_image=True)
    imageio.mimsave('movie.gif', images, duration=0.4)


if __name__ == "__main__":
    parser = build_parser()
    FLAGS = parser.parse_args()
    FLAGS.model = FLAGS.model.upper()
    if FLAGS.name is None:
        FLAGS.name = FLAGS.model.lower()
    config.pprint_args(FLAGS)

    N = FLAGS.sample_size**0.5
    assert N == int(N), 'sample size should be a square number'
    model = config.get_model(FLAGS.model, FLAGS.name, training=False)
    eval(model, name=FLAGS.name, sample_shape=[1, 1], load_all_ckpt=True)
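scipy.misc.imread and scipy.misc.imresize were removed in SciPy 1.2, so to_gif above no longer runs on current SciPy. A sketch of the same helper on maintained libraries (assumes imageio and Pillow are installed):

import os
import numpy as np
import imageio.v2 as imageio
from PIL import Image

def to_gif(dir_name='eval'):
    images = []
    for i in range(9):
        im = Image.open(os.path.join(dir_name, '{}.jpg'.format(i + 1)))
        images.append(np.asarray(im.resize((256, 256))))
    imageio.mimsave('movie.gif', images, duration=0.4)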
Example #10
    def model(self, x, is_training):
        logits, output, self.minLR, self.maxLR, self.step_factor, self.weight_decay = config.get_model(
            self.model_name, x, is_training)
        return logits, output
Example #11
# Assumed imports for this excerpt; utils, config and checkpoints are
# project-local modules.
import logging
import os

import numpy as np
import tensorboardX
import torch
import torch.optim as optim

import checkpoints
import config
import utils


def main(cfg, num_workers):
    # Shortened
    out_dir = cfg['training']['out_dir']
    batch_size = cfg['training']['batch_size']
    backup_every = cfg['training']['backup_every']
    utils.save_config(os.path.join(out_dir, 'config.yml'), cfg)

    model_selection_metric = cfg['training']['model_selection_metric']
    model_selection_sign = 1 if cfg['training'][
        'model_selection_mode'] == 'maximize' else -1

    # Output directory
    utils.cond_mkdir(out_dir)

    # Dataset
    train_dataset = config.get_dataset('train', cfg)
    val_dataset = config.get_dataset('val', cfg)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             shuffle=False)

    # Model
    model = config.get_model(cfg)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    trainer = config.get_trainer(model, optimizer, cfg)

    # Print model
    print(model)
    logger = logging.getLogger(__name__)
    logger.info(
        f'Total number of parameters: {sum(p.numel() for p in model.parameters())}'
    )

    # load pretrained model
    tb_logger = tensorboardX.SummaryWriter(os.path.join(out_dir, 'logs'))
    ckp = checkpoints.CheckpointIO(out_dir, model, optimizer, cfg)
    try:
        load_dict = ckp.load('model_best.pt')
        logger.info('Model loaded')
    except FileExistsError:  # this project's CheckpointIO signals a missing file this way
        logger.info('Model NOT loaded')
        load_dict = dict()

    epoch_it = load_dict.get('epoch_it', -1)
    it = load_dict.get('it', -1)
    metric_val_best = load_dict.get('loss_val_best',
                                    -model_selection_sign * np.inf)

    logger.info(
        f'Current best validation metric ({model_selection_metric}): {metric_val_best:.6f}'
    )

    # Shortened
    print_every = cfg['training']['print_every']
    validate_every = cfg['training']['validate_every']
    max_iterations = cfg['training']['max_iterations']
    max_epochs = cfg['training']['max_epochs']

    while True:
        epoch_it += 1

        for batch in train_loader:
            it += 1
            loss_dict = trainer.train_step(batch)
            loss = loss_dict['total_loss']
            for k, v in loss_dict.items():
                tb_logger.add_scalar(f'train/{k}', v, it)

            # Print output
            if print_every > 0 and (it % print_every) == 0:
                logger.info(
                    f'[Epoch {epoch_it:02d}] it={it:03d}, loss={loss:.8f}')

            # Backup if necessary
            if backup_every > 0 and (it % backup_every) == 0:
                logger.info('Backup checkpoint')
                ckp.save(f'model_{it:d}.pt',
                         epoch_it=epoch_it,
                         it=it,
                         loss_val_best=metric_val_best)

            # Run validation
            if validate_every > 0 and (it % validate_every) == 0:
                eval_dict = trainer.evaluate(val_loader)
                print('eval_dict=\n', eval_dict)
                metric_val = eval_dict[model_selection_metric]
                logger.info(
                    f'Validation metric ({model_selection_metric}): {metric_val:.8f}'
                )

                for k, v in eval_dict.items():
                    tb_logger.add_scalar(f'val/{k}', v, it)

                if model_selection_sign * (metric_val - metric_val_best) > 0:
                    metric_val_best = metric_val
                    logger.info(f'New best model (loss {metric_val_best:.8f})')
                    ckp.save('model_best.pt',
                             epoch_it=epoch_it,
                             it=it,
                             loss_val_best=metric_val_best)

            if (0 < max_iterations <= it) or (0 < max_epochs <= epoch_it):
                logger.info(
                    f'Maximum iterations/epochs reached (it={it}, epoch={epoch_it}). Exiting.'
                )
                ckp.save(f'model_{it:d}.pt',
                         epoch_it=epoch_it,
                         it=it,
                         loss_val_best=metric_val_best)
                exit(3)
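The chained comparison 0 < max_iterations <= it doubles as an "is this limit enabled and reached" test, since a limit of 0 (or less) disables it. In isolation:

def limit_reached(limit, current):
    return 0 < limit <= current

print(limit_reached(0, 10**6))   # False: 0 disables the limit
print(limit_reached(100, 99))    # False: enabled but not reached
print(limit_reached(100, 100))   # True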