Ejemplo n.º 1
0
def save_checkpoint(ckpt_file, model, optimiser, beta, geco_err_ema,
                    iter_idx, verbose=True):
    """Serialise model and optimiser state plus GECO variables to ckpt_file.

    Args:
        ckpt_file: destination path for torch.save.
        model: module whose state_dict is stored.
        optimiser: optimiser whose state_dict is stored.
        beta: current KL weighting term.
        geco_err_ema: GECO error EMA, or None when not training with GECO.
        iter_idx: training iteration the checkpoint corresponds to.
        verbose: if True, log the checkpoint path via fprint.
    """
    if verbose:
        fprint(f"Saving model training checkpoint to: {ckpt_file}")
    state = {
        'model_state_dict': model.state_dict(),
        'optimiser_state_dict': optimiser.state_dict(),
        'beta': beta,
        'iter_idx': iter_idx,
    }
    # The error EMA only exists when training with GECO
    if geco_err_ema is not None:
        state['err_ema'] = geco_err_ema
    torch.save(state, ckpt_file)
Ejemplo n.º 2
0
def load(cfg, **unused_kwargs):
    """Build train/val/test DataLoaders for the ShapeStacks dataset."""
    del unused_kwargs
    if not os.path.exists(cfg.data_folder):
        raise Exception("Data folder does not exist.")
    print(f"Using {cfg.num_workers} data workers.")

    # Optionally stage images and split files on local disk for faster reads
    if cfg.copy_to_tmp:
        for sub in ['/recordings', '/splits']:
            src, dst = cfg.data_folder + sub, '/tmp' + sub
            fprint(f"Copying dataset from {src} to {dst}.")
            copytree(src, dst)
        cfg.data_folder = '/tmp'

    def make_set(mode, **kwargs):
        # All splits share folder, split name, and image size
        return ShapeStacksDataset(cfg.data_folder, cfg.split_name,
                                  mode, cfg.img_size, **kwargs)

    # Training loader: shuffled, full batch size
    tng_loader = DataLoader(make_set('train'),
                            batch_size=cfg.batch_size,
                            shuffle=True,
                            num_workers=cfg.num_workers)
    # Validation loader: deterministic order
    val_loader = DataLoader(make_set('eval'),
                            batch_size=cfg.batch_size,
                            shuffle=False,
                            num_workers=cfg.num_workers)
    # Test loader: one image at a time, single worker
    tst_loader = DataLoader(make_set('test', shuffle_files=cfg.shuffle_test),
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    # Report loading speed of the training loader
    loader_throughput(tng_loader)

    return (tng_loader, val_loader, tst_loader)
Ejemplo n.º 3
0
def load(cfg, **unused_kwargs):
    """Create GQN train/val/test data iterators sharing one TF session."""
    # Fix TensorFlow seed
    global SEED
    SEED = cfg.seed
    tf.set_random_seed(SEED)

    # Reading tfrecords requires at least one worker
    if cfg.num_workers == 0:
        fprint("Need to use at least one worker for loading tfrecords.")
        cfg.num_workers = 1

    del unused_kwargs
    if not os.path.exists(cfg.data_folder):
        raise Exception("Data folder does not exist.")
    print(f"Using {cfg.num_workers} data workers.")

    # Settings shared by all three loaders
    common = dict(data_folder=cfg.data_folder,
                  img_size=cfg.img_size,
                  val_frac=cfg.val_frac,
                  buffer_size=cfg.buffer_size)
    train_loader = GQNLoader(mode='devel_train',
                             batch_size=cfg.batch_size,
                             num_workers=cfg.num_workers,
                             **common)
    val_loader = GQNLoader(mode='devel_val',
                           batch_size=cfg.batch_size,
                           num_workers=cfg.num_workers,
                           **common)
    # Test: single image per batch, single worker
    test_loader = GQNLoader(mode='test',
                            batch_size=1,
                            num_workers=1,
                            **common)

    # All loaders pull data through the same interactive session
    sess = tf.InteractiveSession()
    for ldr in (train_loader, val_loader, test_loader):
        ldr.sess = sess

    # Throughput stats
    if not cfg.debug:
        loader_throughput(train_loader)

    return (train_loader, val_loader, test_loader)
Ejemplo n.º 4
0
def loader_throughput(loader, num_batches=100, burn_in=5):
    """Measure and print data loader throughput.

    Iterates over `loader`, discarding the first `burn_in` batches so the
    workers can warm up, then times the next `num_batches` batches and
    reports seconds per batch and images per second.

    Args:
        loader: iterable yielding dicts with an 'input' tensor of shape
            (batch, ...).
        num_batches: number of batches to time (must be positive).
        burn_in: warm-up batches to skip; None means num_batches // 10.
    """
    assert num_batches > 0
    if burn_in is None:
        burn_in = num_batches // 10
    num_samples = 0
    # Fallback start time so `timer` is defined even if the loader is
    # exhausted before the burn-in completes (previously UnboundLocalError).
    timer = time.time()
    fprint(f"Train loader throughput stats on {num_batches} batches...")
    for i, batch in enumerate(loader):
        if i == burn_in:
            timer = time.time()  # restart the clock once warm-up is done
        if i >= num_batches + burn_in:
            # Stop BEFORE accounting for this batch so exactly `num_batches`
            # batches contribute to the stats. The previous version counted
            # this extra batch's samples, inflating the images/s figure.
            break
        if i >= burn_in:
            num_samples += batch['input'].size(0)
    dt = time.time() - timer
    spb = dt / num_batches
    ips = num_samples / dt
    fprint(f"{spb:.3f} s/b, {ips:.1f} im/s")
Ejemplo n.º 5
0
def load(cfg, **unused_kwargs):
    """Create MineRL train/val/test data iterators."""
    # Fix TensorFlow seed
    global SEED
    SEED = cfg.seed

    # Loading requires at least one worker
    if cfg.num_workers == 0:
        fprint("Need to use at least one worker for loading.")
        cfg.num_workers = 1

    del unused_kwargs
    print(f"Using {cfg.num_workers} data workers.")

    # Settings shared by all three loaders
    shared = dict(img_size=cfg.img_size,
                  val_frac=cfg.val_frac,
                  buffer_size=cfg.buffer_size)
    train_loader = MineRLLoader(mode="devel_train",
                                batch_size=cfg.batch_size,
                                num_workers=cfg.num_workers,
                                **shared)
    val_loader = MineRLLoader(mode="devel_val",
                              batch_size=cfg.batch_size,
                              num_workers=cfg.num_workers,
                              **shared)
    # Test: single image per batch, single worker
    test_loader = MineRLLoader(mode="test",
                               batch_size=1,
                               num_workers=1,
                               **shared)

    # Report loading speed of the training loader
    loader_throughput(train_loader)

    return (train_loader, val_loader, test_loader)
Ejemplo n.º 6
0
def visualise_outputs(model, vis_batch, writer, mode, iter_idx):
    """Log reconstructions, decompositions, and samples to TensorBoard."""
    model.eval()

    # Restrict visualisation to the first eight images
    vis_input = vis_batch['input'][:8]
    if next(model.parameters()).is_cuda:
        vis_input = vis_input.cuda()
    output, losses, stats, att_stats, comp_stats = model(vis_input)

    # Input and reconstruction grids
    writer.add_image(mode+'_input', make_grid(vis_batch['input'][:8]), iter_idx)
    writer.add_image(mode+'_recon', make_grid(output), iter_idx)

    # Per-step decomposition; tensors stored in log space are exponentiated
    for key in ['mx_r_k', 'x_r_k', 'log_m_k', 'log_m_r_k']:
        if key in stats:
            for step, val in enumerate(stats[key]):
                img = val.exp() if 'log' in key else val
                writer.add_image(f'{mode}_{key}/k{step}', make_grid(img),
                                 iter_idx)

    # Unconditional generation, if the model supports sampling
    try:
        output, stats = model.sample(batch_size=8, K_steps=model.K_steps)
        writer.add_image('samples', make_grid(output), iter_idx)
        for key in ['x_k', 'log_m_k', 'mx_k']:
            if key in stats:
                for step, val in enumerate(stats[key]):
                    img = val.exp() if 'log' in key else val
                    writer.add_image(f'gen_{key}/k{step}', make_grid(img),
                                     iter_idx)
    except NotImplementedError:
        fprint("Sampling not implemented for this model.")

    model.train()
Ejemplo n.º 7
0
def dataset_ari(model, data_loader, num_images=1000):
    """Compute (foreground) ARI segmentation scores over a dataset.

    Runs the model on batches from `data_loader` until `num_images` images
    have been scored, logging running averages along the way.

    Args:
        model: segmentation model exposing `log_m_k` in its stats output.
        data_loader: iterable of batches with 'input' and 'instances' keys.
        num_images: number of images to evaluate.

    Returns:
        (avg_ari, avg_ari_fg, ari, ari_fg): dataset averages plus the
        per-image score lists, each truncated to `num_images` entries.
        Returns (0., 0., [0], [0]) if instance labels or segmentation
        masks are unavailable.
    """
    model.eval()

    fprint("Computing ARI on dataset")
    ari = []
    ari_fg = []
    for bidx, batch in enumerate(data_loader):
        if next(model.parameters()).is_cuda:
            batch['input'] = batch['input'].cuda()
        with torch.no_grad():
            _, _, stats, _, _ = model(batch['input'])

        # Return zeros if labels or segmentations are not available
        if 'instances' not in batch or not hasattr(stats, 'log_m_k'):
            return 0., 0., [0], [0]

        _, ari_list = average_ari(stats.log_m_k, batch['instances'])
        _, ari_fg_list = average_ari(stats.log_m_k, batch['instances'], True)
        ari += ari_list
        ari_fg += ari_fg_list

        # Log running averages after every batch (the original guarded this
        # with the always-true condition `bidx % 1 == 0`)
        log_ari = sum(ari) / len(ari)
        log_ari_fg = sum(ari_fg) / len(ari_fg)
        t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        fprint(f"{t} | After [{len(ari)} / {num_images}] images: " +
               f"ARI {log_ari:.4f}, FG ARI {log_ari_fg:.4f}")
        if len(ari) >= num_images:
            break

    assert len(ari) == len(ari_fg)
    ari = ari[:num_images]
    ari_fg = ari_fg[:num_images]

    avg_ari = sum(ari) / len(ari)
    avg_ari_fg = sum(ari_fg) / len(ari_fg)
    fprint(f"FINAL ARI for {len(ari)} images: {avg_ari:.4f}")
    fprint(f"FINAL FG ARI for {len(ari_fg)} images: {avg_ari_fg:.4f}")

    model.train()

    # BUG FIX: previously returned ari_list / ari_fg_list, i.e. only the
    # LAST batch's scores, instead of the accumulated per-image lists.
    return avg_ari, avg_ari_fg, ari, ari_fg
Ejemplo n.º 8
0
def fid_from_model(model,
                   test_loader,
                   batch_size=10,
                   num_images=10000,
                   feat_dim=2048,
                   img_dir='/tmp'):
    """Compute FID between test-set images and samples drawn from the model.

    Dumps `num_images` real and generated images as pngs under `img_dir`
    and runs the FID computation over the two folders. Note: the target
    directories must not already exist (os.makedirs raises otherwise).
    """
    model.eval()

    # Dump real images from the test set to disk as pngs
    fprint("Saving images from test set as pngs.", True)
    test_dir = osp.join(img_dir, 'test_images')
    os.makedirs(test_dir)
    n_saved = 0
    for batch in test_loader:
        n_saved = tensor_to_png(batch['input'], test_dir, n_saved, num_images)
        if n_saved >= num_images:
            break

    # Sample from the model and dump the generated images as pngs
    fprint("Generate images and save as pngs.", True)
    gen_dir = osp.join(img_dir, 'generated_images')
    os.makedirs(gen_dir)
    n_saved = 0
    for _ in tqdm(range(num_images // batch_size + 1)):
        if n_saved >= num_images:
            break
        with torch.no_grad():
            samples, _ = model.sample(batch_size)
        n_saved = tensor_to_png(samples, gen_dir, n_saved, num_images)

    # Run the FID computation over the two image folders
    fprint("Computing FID.", True)
    use_gpu = next(model.parameters()).is_cuda
    fid_value = FID.calculate_fid_given_paths([test_dir, gen_dir], batch_size,
                                              use_gpu, feat_dim)
    fprint(f"FID: {fid_value}", True)

    model.train()

    return fid_value
Ejemplo n.º 9
0
def main():
    """Evaluate FID for a pretrained model configured via forge flags."""
    # Parse flags
    config = forge.config()
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'fid_evaluation.txt'
    config.shuffle_test = True

    # Seed everything immediately after parsing the config
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Select the default tensor type depending on GPU availability
    if torch.cuda.is_available() and config.gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        config.gpu = False
        torch.set_default_tensor_type('torch.FloatTensor')
    fet.print_flags()

    # Data: only the test split is needed for FID
    _, _, test_loader = fet.load(config.data_config, config)

    # Model: rebuild from the saved flags, then restore the weights
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    state = checkpoint['model_state_dict']
    # These pixel-coordinate entries are dropped before loading —
    # presumably regenerated by the model itself; TODO confirm
    for key in ('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                'comp_vae.decoder_module.seq.0.pixel_coords.g_2'):
        state.pop(key, None)
    model.load_state_dict(state)
    fprint(model)
    # Put model on GPU
    if config.gpu:
        model = model.cuda()

    # Compute FID
    fid_from_model(model, test_loader, config.batch_size,
                   config.num_fid_images, config.feat_dim, config.img_dir)
Ejemplo n.º 10
0
    def __next__(self):
        """Fetch the next batch from the TF session and convert it to torch.

        Returns a dict with:
            'input': float image tensor scaled to [0, 1], channels-first,
                centre-cropped / resized to self.img_size when needed.
            'instances': integer instance-segmentation masks aligned with
                'input' (0 = background, k >= 1 = object id).

        Raises StopIteration when the underlying TF iterator is exhausted,
        after resetting the batch counter.
        """
        try:
            frame = self.sess.run(self.frames)
            self.count += 1

            # Parse image
            img = frame['image']
            # Move channel axis from last to second position (NHWC -> NCHW)
            img = np.moveaxis(img, 3, 1)
            shape = img.shape
            # TODO(martin): use more explicit CLEVR flag?
            # Non-square frames are assumed to be CLEVR and centre-cropped
            if shape[2] != shape[3]:
                img = np_img_centre_crop(img, CLEVR_CROP, batch=True)
            # Scale uint8 pixel values into [0, 1]
            img = torch.FloatTensor(img) / 255.
            if self.img_size != shape[2]:
                img = F.interpolate(img, size=self.img_size)

            # Parse masks
            raw_masks = frame['mask']
            masks = np.zeros((shape[0], 1, shape[2], shape[3]), dtype='int')
            # Convert to boolean masks
            # assumes raw_masks is (batch, entities, H, W, channels) with
            # 255 marking membership — TODO confirm against the tfrecords
            cond = np.where(raw_masks[:, :, :, :, 0] == 255, True, False)
            # Ignore background entities
            num_entities = cond.shape[1]
            for o_idx in range(self.background_entities, num_entities):
                # Label each foreground entity with a distinct positive id
                masks[cond[:, o_idx:o_idx + 1, :, :]] = o_idx + 1
            masks = torch.FloatTensor(masks)
            if shape[2] != shape[3]:
                masks = np_img_centre_crop(masks, CLEVR_CROP, batch=True)
            # NOTE(review): masks is converted with torch.FloatTensor twice
            # (before and after the crop); the second conversion looks
            # redundant — confirm before removing
            masks = torch.FloatTensor(masks)
            if self.img_size != shape[2]:
                masks = F.interpolate(masks, size=self.img_size)
            # Interpolation is done in float; cast back to integer labels
            masks = masks.type(torch.LongTensor)

            return {'input': img, 'instances': masks}

        except tf.errors.OutOfRangeError:
            # End of the tf.data pipeline: report counts, reset, and signal
            # exhaustion to the Python for-loop
            fprint("Reached end of epoch. Creating new iterator.")
            fprint(f"Counted {self.count} batches, expected {self.length}.")
            fprint("Creating new iterator.")
            self.count = 0
            raise StopIteration
Ejemplo n.º 11
0
def main():
    """Train a scene-decomposition model end-to-end.

    Parses forge flags, sets up checkpointing/logging, builds data loaders
    and the model, optionally restores from a checkpoint, runs the training
    loop (with GECO or fixed/warmed-up beta), periodically logs to
    TensorBoard and saves checkpoints, then evaluates on the test set and
    attempts an FID computation.
    """
    # ------------------------
    # SETUP
    # ------------------------

    # Parse flags
    config = forge.config()
    if config.debug:
        config.num_workers = 0
        config.batch_size = 2

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Setup checkpoint or resume
    logdir = osp.join(config.results_dir, config.run_name)
    logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                    config.model_config,
                                                    config.resume)
    checkpoint_name = osp.join(logdir, 'model.ckpt')

    # Using GPU(S)?
    if torch.cuda.is_available() and config.gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        config.gpu = False
        torch.set_default_tensor_type('torch.FloatTensor')
    fprint(f"Use GPU: {config.gpu}")
    if config.gpu and config.multi_gpu and torch.cuda.device_count() > 1:
        fprint(f"Using {torch.cuda.device_count()} GPUs!")
        # Scale the number of data workers with the number of GPUs
        config.num_workers = torch.cuda.device_count() * config.num_workers
    else:
        config.multi_gpu = False

    # Print flags
    # fet.print_flags()
    # TODO(martin) make this cleaner
    fprint(json.dumps(fet._flags.FLAGS.__flags, indent=4, sort_keys=True))

    # Setup TensorboardX SummaryWriter
    writer = SummaryWriter(logdir)

    # Load data
    train_loader, val_loader, test_loader = fet.load(config.data_config,
                                                     config)
    num_elements = 3 * config.img_size**2  # Assume three input channels

    # Load model
    model = fet.load(config.model_config, config)
    fprint(model)
    if config.geco:
        # Goal is specified per pixel & channel so it doesn't need to
        # be changed for different resolutions etc.
        geco_goal = config.g_goal * num_elements
        # Scale step size to get similar update at different resolutions
        geco_lr = config.g_lr * (64**2 / config.img_size**2)
        geco = GECO(geco_goal, geco_lr, config.g_alpha, config.g_init,
                    config.g_min, config.g_speedup)
        beta = geco.beta
    else:
        beta = torch.tensor(config.beta)

    # Setup optimiser
    # NOTE(review): an unrecognised config.optimiser leaves `optimiser`
    # unbound and fails below with NameError rather than a clear message
    if config.optimiser == 'rmsprop':
        optimiser = optim.RMSprop(model.parameters(), config.learning_rate)
    elif config.optimiser == 'adam':
        optimiser = optim.Adam(model.parameters(), config.learning_rate)
    elif config.optimiser == 'sgd':
        optimiser = optim.SGD(model.parameters(), config.learning_rate, 0.9)

    # Try to restore model and optimiser from checkpoint
    iter_idx = 0
    if resume_checkpoint is not None:
        fprint(f"Restoring checkpoint from {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location='cpu')
        # Restore model & optimiser
        model_state_dict = checkpoint['model_state_dict']
        # These pixel-coordinate entries are dropped before loading —
        # presumably regenerated by the model itself; TODO confirm
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                             None)
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                             None)
        model.load_state_dict(model_state_dict)
        optimiser.load_state_dict(checkpoint['optimiser_state_dict'])
        # Restore GECO
        if config.geco and 'beta' in checkpoint:
            geco.beta = checkpoint['beta']
        if config.geco and 'err_ema' in checkpoint:
            geco.err_ema = checkpoint['err_ema']
        # Update starting iter
        iter_idx = checkpoint['iter_idx'] + 1
    fprint(f"Starting training at iter = {iter_idx}")

    # Push model to GPU(s)?
    if config.multi_gpu:
        fprint("Wrapping model in DataParallel.")
        model = nn.DataParallel(model)
    if config.gpu:
        fprint("Pushing model to GPU.")
        model = model.cuda()
        if config.geco:
            geco.to_cuda()

    # ------------------------
    # TRAINING
    # ------------------------

    model.train()
    timer = time.time()
    # Outer loop restarts the data loader each epoch until the iteration
    # budget is exhausted
    while iter_idx <= config.train_iter:
        for train_batch in train_loader:
            # Parse data
            train_input = train_batch['input']
            if config.gpu:
                train_input = train_input.cuda()

            # Forward propagation
            optimiser.zero_grad()
            output, losses, stats, att_stats, comp_stats = model(train_input)

            # Reconstruction error
            err = losses.err.mean(0)
            # KL divergences
            kl_m, kl_l = torch.tensor(0), torch.tensor(0)
            # -- KL stage 1
            if 'kl_m' in losses:
                kl_m = losses.kl_m.mean(0)
            elif 'kl_m_k' in losses:
                kl_m = torch.stack(losses.kl_m_k, dim=1).mean(dim=0).sum()
            # -- KL stage 2
            if 'kl_l' in losses:
                kl_l = losses.kl_l.mean(0)
            elif 'kl_l_k' in losses:
                kl_l = torch.stack(losses.kl_l_k, dim=1).mean(dim=0).sum()

            # Compute ELBO
            elbo = (err + kl_l + kl_m).detach()
            err_new = err.detach()
            kl_new = (kl_m + kl_l).detach()
            # Compute MSE / RMSE
            mse_batched = ((train_input - output)**2).mean((1, 2, 3)).detach()
            rmse_batched = mse_batched.sqrt()
            mse, rmse = mse_batched.mean(0), rmse_batched.mean(0)

            # Main objective
            if config.geco:
                loss = geco.loss(err, kl_l + kl_m)
                beta = geco.beta
            else:
                if config.beta_warmup:
                    # Increase beta linearly over 20% of training
                    beta = config.beta * iter_idx / (0.2 * config.train_iter)
                    beta = torch.tensor(beta).clamp(0, config.beta)
                else:
                    beta = config.beta
                loss = err + beta * (kl_l + kl_m)

            # Backprop and optimise
            loss.backward()
            optimiser.step()

            # Heartbeat log
            if (iter_idx % config.report_loss_every == 0
                    or float(elbo) > ELBO_DIV or config.debug):
                # Print output and write to file
                ps = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                ps += f' {config.run_name} | '
                ps += f'[{iter_idx}/{config.train_iter:.0e}]'
                ps += f' elb: {float(elbo):.0f} err: {float(err):.0f} '
                if 'kl_m' in losses or 'kl_m_k' in losses:
                    ps += f' klm: {float(kl_m):.1f}'
                if 'kl_l' in losses or 'kl_l_k' in losses:
                    ps += f' kll: {float(kl_l):.1f}'
                ps += f' bet: {float(beta):.1e}'
                s_per_b = (time.time() - timer)
                if not config.debug:
                    s_per_b /= config.report_loss_every
                timer = time.time()  # Reset timer
                ps += f' - {s_per_b:.2f} s/b'
                fprint(ps)

                # TensorBoard logging
                # -- Optimisation stats
                writer.add_scalar('optim/beta', beta, iter_idx)
                writer.add_scalar('optim/s_per_batch', s_per_b, iter_idx)
                if config.geco:
                    writer.add_scalar('optim/geco_err_ema', geco.err_ema,
                                      iter_idx)
                    writer.add_scalar('optim/geco_err_ema_element',
                                      geco.err_ema / num_elements, iter_idx)
                # -- Main loss terms
                writer.add_scalar('train/err', err, iter_idx)
                writer.add_scalar('train/err_element', err / num_elements,
                                  iter_idx)
                writer.add_scalar('train/kl_m', kl_m, iter_idx)
                writer.add_scalar('train/kl_l', kl_l, iter_idx)
                writer.add_scalar('train/elbo', elbo, iter_idx)
                writer.add_scalar('train/loss', loss, iter_idx)
                writer.add_scalar('train/mse', mse, iter_idx)
                writer.add_scalar('train/rmse', rmse, iter_idx)
                # -- Per step loss terms
                for key in ['kl_l_k', 'kl_m_k']:
                    if key not in losses: continue
                    for step, val in enumerate(losses[key]):
                        writer.add_scalar(f'train_steps/{key}{step}',
                                          val.mean(0), iter_idx)
                # -- Attention stats
                if config.log_distributions and att_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in att_stats: continue
                        for step, val in enumerate(att_stats[key]):
                            writer.add_histogram(f'att_{key}_{step}', val,
                                                 iter_idx)
                # -- Component stats
                if config.log_distributions and comp_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in comp_stats: continue
                        for step, val in enumerate(comp_stats[key]):
                            writer.add_histogram(f'comp_{key}_{step}', val,
                                                 iter_idx)

            # Save checkpoints
            # NOTE(review): ckpt_freq is a float when train_iter is not a
            # multiple of num_checkpoints; the modulo below then rarely
            # hits exactly zero — confirm the intended checkpoint cadence
            ckpt_freq = config.train_iter / config.num_checkpoints
            if iter_idx % ckpt_freq == 0:
                ckpt_file = '{}-{}'.format(checkpoint_name, iter_idx)
                fprint(f"Saving model training checkpoint to: {ckpt_file}")
                if config.multi_gpu:
                    model_state_dict = model.module.state_dict()
                else:
                    model_state_dict = model.state_dict()
                ckpt_dict = {
                    'iter_idx': iter_idx,
                    'model_state_dict': model_state_dict,
                    'optimiser_state_dict': optimiser.state_dict(),
                    'elbo': elbo
                }
                if config.geco:
                    ckpt_dict['beta'] = geco.beta
                    ckpt_dict['err_ema'] = geco.err_ema
                torch.save(ckpt_dict, ckpt_file)

            # Run validation and log images
            if (iter_idx % config.run_validation_every == 0
                    or float(elbo) > ELBO_DIV):
                # Weight and gradient histograms
                if config.log_grads_and_weights:
                    for name, param in model.named_parameters():
                        writer.add_histogram(f'weights/{name}', param.data,
                                             iter_idx)
                        writer.add_histogram(f'grads/{name}', param.grad,
                                             iter_idx)
                # TensorboardX logging - images
                visualise_inference(model, train_batch, writer, 'train',
                                    iter_idx)
                # Validation
                fprint("Running validation...")
                eval_model = model.module if config.multi_gpu else model
                evaluation(eval_model,
                           val_loader,
                           writer,
                           config,
                           iter_idx,
                           N_eval=config.N_eval)

            # Increment counter
            iter_idx += 1
            if iter_idx > config.train_iter:
                break

            # Exit if training has diverged
            if elbo.item() > ELBO_DIV:
                fprint(f"ELBO: {elbo.item()}")
                fprint(
                    f"ELBO has exceeded {ELBO_DIV} - training has diverged.")
                sys.exit()

    # ------------------------
    # TESTING
    # ------------------------

    # Save final checkpoint
    ckpt_file = '{}-{}'.format(checkpoint_name, 'FINAL')
    fprint(f"Saving model training checkpoint to: {ckpt_file}")
    if config.multi_gpu:
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()
    ckpt_dict = {
        'iter_idx': iter_idx,
        'model_state_dict': model_state_dict,
        'optimiser_state_dict': optimiser.state_dict()
    }
    if config.geco:
        ckpt_dict['beta'] = geco.beta
        ckpt_dict['err_ema'] = geco.err_ema
    torch.save(ckpt_dict, ckpt_file)

    # Test evaluation
    fprint("STARTING TESTING...")
    eval_model = model.module if config.gpu and config.multi_gpu else model
    final_elbo = evaluation(eval_model,
                            test_loader,
                            None,
                            config,
                            iter_idx,
                            N_eval=config.N_eval)
    fprint(f"TEST ELBO = {float(final_elbo)}")

    # FID computation
    try:
        fid_from_model(model, test_loader, img_dir=osp.join('/tmp', logdir))
    except NotImplementedError:
        fprint("Sampling not implemented for this model.")

    # Close writer
    writer.close()
Ejemplo n.º 12
0
def load(cfg, **unused_kwargs):
    """Build train/val/test loaders for the multi-object datasets suite.

    Selects one of the multi_dsprites / objects_room / clevr / tetrominoes
    tfrecord datasets, fills in per-dataset defaults for image size and
    attention steps K (when the config leaves them negative), splits the
    frames into train/val/test, and wraps each split in a MultiOjectLoader.

    Returns:
        (tng_loader, val_loader, tst_loader)

    Raises:
        NotImplementedError: if cfg.dataset names an unknown dataset.
    """
    # Fix TensorFlow seed
    global SEED
    SEED = cfg.seed
    tf.set_random_seed(SEED)

    del unused_kwargs
    fprint(f"Using {cfg.num_workers} data workers.")

    sess = tf.InteractiveSession()

    # Per-dataset defaults; map_parallel_calls=None lets TF decide when
    # no workers were requested
    if cfg.dataset == 'multi_dsprites':
        cfg.img_size = 64 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 5 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 1
        max_frames = 60000
        raw_dataset = multi_dsprites.dataset(cfg.data_folder + MULTI_DSPRITES,
                                             'colored_on_colored',
                                             map_parallel_calls=cfg.num_workers
                                             if cfg.num_workers > 0 else None)
    elif cfg.dataset == 'objects_room':
        cfg.img_size = 64 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 7 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 4
        max_frames = 1000000
        raw_dataset = objects_room.dataset(cfg.data_folder + OBJECTS_ROOM,
                                           'train',
                                           map_parallel_calls=cfg.num_workers
                                           if cfg.num_workers > 0 else None)
    elif cfg.dataset == 'clevr':
        cfg.img_size = 128 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 11 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 1
        max_frames = 70000
        raw_dataset = clevr_with_masks.dataset(
            cfg.data_folder + CLEVR,
            map_parallel_calls=cfg.num_workers
            if cfg.num_workers > 0 else None)
    elif cfg.dataset == 'tetrominoes':
        cfg.img_size = 32 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 4 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 1
        max_frames = 60000
        raw_dataset = tetrominoes.dataset(cfg.data_folder + TETROMINOS,
                                          map_parallel_calls=cfg.num_workers
                                          if cfg.num_workers > 0 else None)
    else:
        raise NotImplementedError(f"{cfg.dataset} not a valid dataset.")

    # Split into train / val / test
    if cfg.dataset_size > max_frames:
        # BUG FIX: the second string was missing the f-prefix, so
        # "{max_frames}" was printed literally instead of its value.
        fprint(f"WARNING: {cfg.dataset_size} frames requested, "
               f"but only {max_frames} available.")
        cfg.dataset_size = max_frames
    if cfg.dataset_size > 0:
        total_sz = cfg.dataset_size
        raw_dataset = raw_dataset.take(total_sz)
    else:
        total_sz = max_frames
    # NOTE(review): total_sz is always positive here given the max_frames
    # values above, so this branch looks unreachable — confirm
    if total_sz < 0:
        fprint("Determining size of dataset...")
        total_sz = len_tfrecords(raw_dataset, sess)
    fprint(f"Dataset has {total_sz} frames")

    val_sz = 10000
    tst_sz = 10000
    tng_sz = total_sz - val_sz - tst_sz
    assert tng_sz > 0
    fprint(f"Splitting into {tng_sz}/{val_sz}/{tst_sz} for tng/val/tst")
    # Test frames come first, then validation, then the training remainder
    tst_dataset = raw_dataset.take(tst_sz)
    val_dataset = raw_dataset.skip(tst_sz).take(val_sz)
    tng_dataset = raw_dataset.skip(tst_sz + val_sz)

    tng_loader = MultiOjectLoader(sess, tng_dataset, background_entities,
                                  tng_sz, cfg.batch_size, cfg.img_size,
                                  cfg.buffer_size)
    val_loader = MultiOjectLoader(sess, val_dataset, background_entities,
                                  val_sz, cfg.batch_size, cfg.img_size,
                                  cfg.buffer_size)
    tst_loader = MultiOjectLoader(sess, tst_dataset, background_entities,
                                  tst_sz, cfg.batch_size, cfg.img_size,
                                  cfg.buffer_size)

    # Throughput stats
    if not cfg.debug:
        loader_throughput(tng_loader)

    return (tng_loader, val_loader, tst_loader)
Ejemplo n.º 13
0
 def __iter__(self):
     """Start a fresh pass over the dataset and return self as iterator."""
     fprint("Creating new one_shot_iterator.")
     # A new one-shot iterator restarts the underlying tf.data pipeline;
     # self.frames is the symbolic next-batch op evaluated by __next__
     it = self.dataset.make_one_shot_iterator()
     self.frames = it.get_next()
     return self
Ejemplo n.º 14
0
def main():
    """Visualise unconditional samples from a pretrained scene model.

    Restores the flags and weights of a pretrained model from
    ``config.model_dir``, fixes all RNG seeds for reproducibility, then
    repeatedly samples scenes and shows a grid with the generated image
    and the per-step masked RGB, raw RGB, mask, and (if the model
    produces one) scope components.
    """
    # Parse flags
    config = forge.config()
    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    # Override: sample one scene at a time, on CPU, in debug mode
    pretrained_flags.batch_size = 1
    pretrained_flags.gpu = False
    pretrained_flags.debug = True
    fet.print_flags()

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load model architecture from its config, then restore the weights
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    # Drop pixel-coordinate buffers from the checkpoint; presumably they
    # are rebuilt by the model at construction time -- TODO confirm
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for _ in range(100):
        # Draw one unconditional sample using K generation steps
        y, stats = model.sample(1, pretrained_flags.K_steps)
        # 4 rows x (1 summary column + K per-step columns)
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Generated scene in the summary column
        plot(axes, 0, 0, y, title='Generated scene', fontsize=12)
        # Empty plots fill the rest of the summary column
        plot(axes, 1, 0, fontsize=12)
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K generation steps in separate subfigures
        for step in range(pretrained_flags.K_steps):
            x_step = stats['x_k'][step]
            m_step = stats['log_m_k'][step].exp()
            mx_step = stats['mx_k'][step]
            # Scopes are only present for some models; guarded again below
            if 'log_s_k' in stats:
                s_step = stats['log_s_k'][step].exp()
            # Row labels only on the first step column
            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes,
                 2,
                 1 + step,
                 m_step,
                 pre + f'k={step+1}',
                 True,
                 fontsize=12)
            if 'log_s_k' in stats:
                pre = 'Scope ' if step == 0 else ''
                plot(axes,
                     3,
                     1 + step,
                     s_step,
                     pre + f'k={step+1}',
                     True,
                     axis=step == 0,
                     fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        # NOTE(review): maximising via manager.window.maxsize() is
        # backend-specific (e.g. TkAgg) -- confirm the intended backend
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
Ejemplo n.º 15
0
def main():
    """Evaluate segmentation metrics of a pretrained model on one split.

    Restores a pretrained model from ``config.model_dir``, prefetches
    ``config.num_images`` batches from the requested split, then computes
    and prints average foreground ARI and segmentation covering scores.

    Raises:
        ValueError: if ``config.split`` is not one of
            'train' / 'val' / 'test'.
    """
    # Parse flags
    config = forge.config()
    config.batch_size = 1
    config.load_instances = True
    fet.print_flags()

    # Restore original model flags
    pretrained_flags = AttrDict(
        fet.json_load(os.path.join(config.model_dir, 'flags.json')))

    # Get data loaders for all splits and pick the requested one
    train_loader, val_loader, test_loader = fet.load(config.data_config,
                                                     config)
    fprint(f"Split: {config.split}")
    if config.split == 'train':
        batch_loader = train_loader
    elif config.split == 'val':
        batch_loader = val_loader
    elif config.split == 'test':
        batch_loader = test_loader
    else:
        # Fail fast with a clear message instead of a NameError below
        raise ValueError(f"Invalid split: {config.split}")
    # Shuffle and prefetch to get same data for different models
    if 'gqn' not in config.data_config:
        batch_loader = torch.utils.data.DataLoader(batch_loader.dataset,
                                                   batch_size=1,
                                                   num_workers=0,
                                                   shuffle=True)
    # Prefetch batches
    prefetched_batches = []
    for i, x in enumerate(batch_loader):
        if i == config.num_images:
            break
        prefetched_batches.append(x)

    # Load model architecture, then restore the weights
    model = fet.load(config.model_config, pretrained_flags)
    fprint(model)
    model_path = os.path.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    # Drop pixel-coordinate buffers from the checkpoint; presumably they
    # are rebuilt by the model at construction time -- TODO confirm
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)

    # Set experiment folder and fprint file for logging
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'segmentation_metrics.txt'

    # Compute metrics
    model.eval()
    ari_fg_list, sc_fg_list, msc_fg_list = [], [], []
    with torch.no_grad():
        for i, x in enumerate(tqdm(prefetched_batches)):
            _, _, stats, _, _ = model(x['input'])
            # ARI
            ari_fg, _ = average_ari(stats.log_m_k,
                                    x['instances'],
                                    foreground_only=True)
            # Segmentation covering - foreground only. Mark background
            # pixels with a label that cannot match any prediction.
            gt_instances = x['instances'].clone()
            gt_instances[gt_instances == 0] = -100
            ins_preds = torch.argmax(torch.stack(stats.log_m_k, dim=1), dim=1)
            # average_segcover returns a tuple (as unpacked elsewhere in
            # this codebase); keep only the mean score -- appending the
            # whole tuple would break the sum()/len() averaging below
            sc_fg, _ = average_segcover(gt_instances, ins_preds)
            msc_fg, _ = average_segcover(gt_instances, ins_preds, False)
            # Recording
            ari_fg_list.append(ari_fg)
            sc_fg_list.append(sc_fg)
            msc_fg_list.append(msc_fg)

    # Print average metrics
    fprint(f"Average FG ARI: {sum(ari_fg_list)/len(ari_fg_list)}")
    fprint(f"Average FG SegCover: {sum(sc_fg_list)/len(sc_fg_list)}")
    fprint(f"Average FG MeanSegCover: {sum(msc_fg_list)/len(msc_fg_list)}")
Ejemplo n.º 16
0
def evaluation(model, data_loader, writer, config, iter_idx, N_eval=None):
    """Run a validation pass and log losses/visualisations to TensorBoard.

    Args:
        model: model whose forward pass returns
            ``(output, losses, stats, ...)`` and which may implement
            ``sample(batch_size, K_steps)``.
        data_loader: iterable of batches, each a dict with an 'input' tensor.
        writer: TensorBoard SummaryWriter, or None to skip logging.
        config: experiment flags; uses ``gpu``, ``debug`` and ``K_steps``.
        iter_idx: current training iteration; at iteration 0 (or in debug
            mode) only a single batch is evaluated.
        N_eval: if given and smaller than the loader, evaluate only on
            roughly the first N_eval examples.

    Returns:
        The ELBO estimate averaged over the evaluated batches.
    """
    # TODO(martin): make interface cleaner

    model.eval()

    t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    if iter_idx == 0 or config.debug:
        num_batches = 1
    elif N_eval is not None and N_eval <= len(
            data_loader) * data_loader.batch_size:
        num_batches = N_eval // data_loader.batch_size
        fprint(t + f" | Evaluating only on first {N_eval} examples in loader")
    else:
        num_batches = len(data_loader)
        fprint(t + f" | Evaluating on all {num_batches} examples in loader")

    start_t = time.time()
    err, kl_l, kl_m, elbo = 0., 0., 0., 0.
    batch = None

    # Don't compute gradient to run faster
    with torch.no_grad():
        # Loop over loader
        for b_idx, batch in enumerate(data_loader):
            if config.gpu:
                batch['input'] = batch['input'].cuda()
            if b_idx == num_batches:
                fprint(f"Breaking from eval loop after {b_idx} batches")
                break

            if b_idx % 100 == 0:
                t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                fprint(t + f" | Validation batch [{b_idx+1} | {num_batches}]")

            _, losses, stats, _, _ = model(batch['input'])

            new_err = losses.err.mean(0)
            err += float(new_err) / num_batches
            # Zero-initialise the KL terms for this batch. Without this,
            # a model that reports neither flavour of a KL loss makes the
            # ELBO estimate below read an unbound variable (NameError) --
            # or a stale value carried over from a previous batch.
            new_kl_m = torch.tensor(0.)
            new_kl_l = torch.tensor(0.)
            # Parse different loss types: either a single summed KL or a
            # per-step list that needs summing over steps first
            if 'kl_m' in losses:
                new_kl_m = losses.kl_m.mean(0)
                kl_m += float(new_kl_m) / num_batches
            elif 'kl_m_k' in losses:
                new_kl_m = torch.stack(losses.kl_m_k, dim=1).sum(1).mean(0)
                kl_m += float(new_kl_m) / num_batches
            if 'kl_l' in losses:
                new_kl_l = losses.kl_l.mean(0)
                kl_l += float(new_kl_l) / num_batches
            elif 'kl_l_k' in losses:
                new_kl_l = torch.stack(losses.kl_l_k, dim=1).sum(1).mean(0)
                kl_l += float(new_kl_l) / num_batches
            # Update ELBO
            if 'elbo' not in losses:
                # Assign current "estimate" from the individual terms
                elbo += float(new_err + new_kl_l + new_kl_m) / num_batches
            else:
                # Model provides the ELBO directly
                elbo += float(losses.elbo.mean(0)) / num_batches

    # Printing
    duration = time.time() - start_t
    pstr = f'Evaluation elbo: {elbo:.1f}'
    pstr += f', err: {err:.1f}, kl_l: {kl_l:.1f}'
    pstr += f', kl_m: {kl_m:.1f}'
    pstr += f' --- {num_batches / duration:.1f} b/s'
    fprint(pstr)

    # TensorBoard logging
    if writer is not None:
        # TensorBoard logging - scalars
        writer.add_scalar('val/elbo', elbo, iter_idx)
        writer.add_scalar('val/err', err, iter_idx)
        writer.add_scalar('val/kl_l', kl_l, iter_idx)
        writer.add_scalar('val/kl_m', kl_m, iter_idx)
        # TensorBoard logging - inference (limit to 8)
        visualise_inference(model, batch, writer, 'val', iter_idx)
        # TensorBoard logging - generation (limit to 8)
        try:
            output, stats = model.sample(batch_size=8, K_steps=config.K_steps)
            writer.add_image('samples', make_grid(output), iter_idx)
            for key in ['x_k', 'log_m_k', 'mx_k']:
                if key not in stats:
                    continue
                for step, val in enumerate(stats[key]):
                    if 'log' in key:
                        val = val.exp()
                    writer.add_image(f'gen_{key}/k{step}', make_grid(val),
                                     iter_idx)
        except NotImplementedError:
            fprint("Sampling not implemented for this model.")

    model.train()

    return elbo
Ejemplo n.º 17
0
def evaluation(model, data_loader, writer, config, iter_idx,
               N_eval=None, N_seg_metrics=50):
    """Run a validation pass collecting losses, ELBO and segmentation metrics.

    Args:
        model: model whose forward pass returns
            ``(output, losses, stats, ...)``.
        data_loader: iterable of batch dicts; must contain 'input', and
            may contain 'instances' for segmentation metrics.
        writer: TensorBoard SummaryWriter, or None to skip logging.
        config: experiment flags; uses ``gpu`` and ``debug``.
        iter_idx: current training iteration; at iteration 0 (or in debug
            mode) only a single batch is evaluated.
        N_eval: if given and smaller than the loader, evaluate only on
            roughly the first N_eval examples.
        N_seg_metrics: number of examples on which to compute ARI and
            segmentation covering.

    Returns:
        Dict mapping stat names to Python floats, averaged over batches.
    """
    model.eval()
    torch.set_grad_enabled(False)

    batch_size = data_loader.batch_size

    # Decide how many batches to evaluate on
    if iter_idx == 0 or config.debug:
        num_batches = 1
        fprint(f"ITER 0 / DEBUG - eval on {num_batches} batches", True)
    elif N_eval is not None and N_eval <= len(data_loader)*batch_size:
        num_batches = int(N_eval // batch_size)
        fprint(f"N_eval = {N_eval}, eval on {num_batches} batches", True)
    else:
        num_batches = len(data_loader)
        fprint(f"Eval on all {num_batches} batches")

    start_t = time.time()
    # Accumulates a list of per-batch values for each stat key
    eval_stats = AttrDefault(list, {})
    batch = None

    # Loop over loader
    for b_idx, batch in enumerate(data_loader):
        if b_idx == num_batches:
            fprint(f"Breaking from eval loop after {b_idx} batches")
            break

        if config.gpu:
            for key, val in batch.items():
                batch[key] = val.cuda()

        # Forward pass
        _, losses, stats, _, _ = model(batch['input'])

        # Track individual loss terms
        for key, val in losses.items():
            # Sum over steps if needed
            if isinstance(val, list):
                eval_stats[key].append(torch.stack(val, 1).sum(1).mean(0))
            else:
                eval_stats[key].append(val.mean(0))

        # Track ELBO: reconstruct it from error plus whichever KL terms
        # (per-step list or pre-summed) the model reports; missing terms
        # contribute zero
        kl_m, kl_l = torch.tensor(0), torch.tensor(0)
        if 'kl_m_k' in losses:
            kl_m = torch.stack(losses.kl_m_k, dim=1).sum(1).mean(0)
        elif 'kl_m' in losses:
            kl_m = losses.kl_m.mean(0)
        if 'kl_l_k' in losses:
            kl_l = torch.stack(losses.kl_l_k, dim=1).sum(1).mean(0)
        elif 'kl_l' in losses:
            kl_l = losses.kl_l.mean(0)
        eval_stats['elbo'].append(losses.err.mean(0) + kl_m + kl_l)

        # Track segmentation metrics on (roughly) the first N_seg_metrics
        # examples only, and only when ground-truth instances and
        # predicted masks are available
        if      ('instances' in batch and 'log_m_k' in stats and
                 b_idx*batch_size < N_seg_metrics):
            # ARI
            new_ari, _ = average_ari(
                stats.log_m_k, batch['instances'])
            new_ari_fg, _ = average_ari(
                stats.log_m_k, batch['instances'], True)
            eval_stats['ari'].append(new_ari)
            eval_stats['ari_fg'].append(new_ari_fg)
            # Segmentation Covering: hard-assign each pixel to the most
            # likely component
            iseg = torch.argmax(torch.cat(stats.log_m_k, 1), 1, True)
            msc, _ = average_segcover(batch['instances'], iseg)
            msc_fg, _ = average_segcover(batch['instances'], iseg,
                                         ignore_background=True)
            eval_stats['msc'].append(msc)
            eval_stats['msc_fg'].append(msc_fg)

    # Average each stat over the batches it was collected on
    for key, val in eval_stats.items():
        # Sanity check: seg metrics were computed on ~N_seg_metrics examples
        if ('ari' in key or 'msc' in key) and not config.debug and iter_idx > 0:
            assert len(val)*batch_size >= N_seg_metrics
            assert len(val)*batch_size < N_seg_metrics+batch_size
        eval_stats[key] = sum(val) / len(val)

    # Track element-wise error, normalised by pixels/channels of one input
    nelements = np.prod(batch['input'].shape[1:4])
    eval_stats['err_element'] = eval_stats['err'] / nelements

    # Printing
    duration = time.time() - start_t
    fprint(f'Eval duration: {duration:.1f}s, {num_batches / duration:.1f} b/s')
    eval_stats['duration'] = duration
    eval_stats['num_batches'] = num_batches
    # Freeze into a plain dict of floats (detaches any tensors)
    eval_stats = dict(eval_stats)
    for key, val in eval_stats.items():
        eval_stats[key] = float(val)

    # TensorBoard logging
    if writer is not None:
        log_scalars(eval_stats, 'val', iter_idx, writer)

    # Restore training mode and gradient tracking for the caller
    model.train()
    torch.set_grad_enabled(True)

    return eval_stats
Ejemplo n.º 18
0
def main():
    """Visualise reconstructions of test images by a pretrained model.

    Restores a pretrained model from ``config.model_dir``, fixes all RNG
    seeds, then for each test image shows the input, its reconstruction,
    and the per-step masked RGB, raw RGB, mask, and (if the model
    produces one) scope components.
    """
    # Parse flags
    config = forge.config()
    fet.print_flags()
    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.debug = True

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load data one image at a time; only the test split is used
    config.batch_size = 1
    _, _, test_loader = fet.load(config.data_config, config)

    # Load model architecture, then restore the weights
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    # Drop pixel-coordinate buffers from the checkpoint; presumably they
    # are rebuilt by the model at construction time -- TODO confirm
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for count, batch in enumerate(test_loader):
        if count >= config.num_images:
            break

        # Forward pass
        output, _, stats, _, _ = model(batch['input'])
        # Set up figure: 4 rows x (1 summary column + K step columns)
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Input and reconstruction
        plot(axes, 0, 0, batch['input'], title='Input image', fontsize=12)
        plot(axes, 1, 0, output, title='Reconstruction', fontsize=12)
        # Empty plots fill the rest of the summary column
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K reconstruction steps into separate subfigures
        x_k = stats['x_r_k']
        log_m_k = stats['log_m_k']
        # Masked RGB: per-step appearance weighted by its mask
        mx_k = [x * m.exp() for x, m in zip(x_k, log_m_k)]
        # Scopes are only present for some models
        log_s_k = stats['log_s_k'] if 'log_s_k' in stats else None
        for step in range(pretrained_flags.K_steps):
            mx_step = mx_k[step]
            x_step = x_k[step]
            m_step = log_m_k[step].exp()
            if log_s_k:
                s_step = log_s_k[step].exp()

            # Row labels only on the first step column
            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes,
                 2,
                 1 + step,
                 m_step,
                 pre + f'k={step+1}',
                 True,
                 fontsize=12)
            if log_s_k:
                pre = 'Scope ' if step == 0 else ''
                plot(axes,
                     3,
                     1 + step,
                     s_step,
                     pre + f'k={step+1}',
                     True,
                     axis=step == 0,
                     fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.15)
        # NOTE(review): maximising via manager.window.maxsize() is
        # backend-specific (e.g. TkAgg) -- confirm the intended backend
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()