예제 #1
0
def main(config):
    """Train the encoder / generator / discriminator model described by ``config``.

    The function seeds the RNGs, prepares the results directory, builds the
    dataset and the ``Generator``, ``Encoder`` and ``Discriminator`` classes
    found in ``models.<config['arch']>``, and resumes from the latest saved
    epoch when one exists. Each batch performs two updates:

    1. a discriminator step on ``mean(D(codes)) - mean(D(noise))`` plus a
       gradient penalty computed on interpolated latent vectors;
    2. an encoder + generator step on the reconstruction loss plus
       ``-mean(D(codes))``.

    After every epoch a loss plot and sample renders are written to
    ``<results_dir>/samples``; every ``config['save_frequency']`` epochs the
    model and optimizer states are written to ``<results_dir>/weights``.
    The per-batch loss history used for the plot starts empty on every call,
    including resumed runs.

    Args:
        config: Mapping with the run settings. Keys read here: ``seed``,
            ``cuda``, ``gpu``, ``dataset``, ``data_dir``, ``classes``,
            ``batch_size``, ``shuffle``, ``num_workers``, ``arch``,
            ``reconstruction_loss``, ``reconstruction_coef``,
            ``distribution`` (with ``p`` or ``z_beta_a`` / ``z_beta_b``),
            ``z_size``, ``optimizer`` (``EG`` and ``D`` entries with ``type``
            and ``hyperparams``), ``gp_lambda``, ``max_epochs`` and
            ``save_frequency``.

    Raises:
        ValueError: If the dataset name, the reconstruction loss or the
            latent distribution named in ``config`` is not supported.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])
    # NOTE(review): NumPy's global RNG, used below for the `beta` prior, is
    # not seeded here - confirm whether that is intended.

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Write the configuration only once so a resumed run keeps the first copy.
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')
    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    # drop_last=True keeps every batch at exactly config['batch_size'] rows,
    # which the gradient-penalty `alpha` and the beta noise below rely on.
    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)
    D = arch.Discriminator(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)
    D.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')
    #
    # Float Tensors
    #
    # `fixed_noise` is sampled once and reused after every epoch so the
    # rendered generator samples are comparable between epochs.
    distribution = config['distribution'].lower()
    if distribution == 'bernoulli':
        p = torch.tensor(config['p']).to(device)
        sampler = Bernoulli(probs=p)
        fixed_noise = sampler.sample(torch.Size([config['batch_size'],
                                                 config['z_size']]))
    elif distribution == 'beta':
        fixed_noise_np = np.random.beta(config['z_beta_a'],
                                        config['z_beta_b'],
                                        size=(config['batch_size'],
                                              config['z_size']))
        fixed_noise = torch.tensor(fixed_noise_np).float().to(device)
    else:
        raise ValueError('Invalid distribution for binaray model.')

    #
    # Optimizers
    #
    # The optimizer classes are looked up by name in `torch.optim`.
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    D_optim = getattr(optim, config['optimizer']['D']['type'])
    D_optim = D_optim(D.parameters(),
                      **config['optimizer']['D']['hyperparams'])

    # Resume: restore model and optimizer states of the last finished epoch.
    if starting_epoch > 1:
        G.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_E.pth')))
        D.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_D.pth')))

        D_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_Do.pth')))

        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    # Per-batch loss history for the loss plot. Plain Python floats are
    # stored: appending the tensors themselves would keep every batch's
    # autograd graph alive for the whole run.
    loss_d_tot, loss_gp_tot, loss_e_tot, loss_g_tot = [], [], [], []
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()
        D.train()

        total_loss_eg = 0.0
        total_loss_d = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            codes = E(X)
            if distribution == 'bernoulli':
                noise = sampler.sample(fixed_noise.size())
            elif distribution == 'beta':
                noise_np = np.random.beta(config['z_beta_a'],
                                          config['z_beta_b'],
                                          size=(config['batch_size'],
                                                config['z_size']))
                noise = torch.tensor(noise_np).float().to(device)
            synth_logit = D(codes)
            real_logit = D(noise)
            loss_d = torch.mean(synth_logit) - torch.mean(real_logit)
            # Recorded before the penalty is added; the penalty has its own
            # curve in the plot.
            loss_d_tot.append(loss_d.item())

            # Gradient Penalty
            # Penalise the discriminator when the norm of its gradient at
            # random points between `noise` and `codes` differs from 1.
            alpha = torch.rand(config['batch_size'], 1).to(device)
            differences = codes - noise
            interpolates = noise + alpha * differences
            disc_interpolates = D(interpolates)

            gradients = grad(
                outputs=disc_interpolates,
                inputs=interpolates,
                grad_outputs=torch.ones_like(disc_interpolates).to(device),
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
            slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
            gradient_penalty = ((slopes - 1) ** 2).mean()
            loss_gp = config['gp_lambda'] * gradient_penalty
            loss_gp_tot.append(loss_gp.item())
            ###

            loss_d += loss_gp

            D_optim.zero_grad()
            D.zero_grad()

            # retain_graph=True: the encoder part of the graph that produced
            # `codes` is reused by the encoder/generator step below.
            loss_d.backward(retain_graph=True)
            total_loss_d += loss_d.item()
            D_optim.step()

            # EG part of training
            X_rec = G(codes)

            loss_e = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))
            loss_e_tot.append(loss_e.item())

            # Re-evaluate the freshly updated discriminator on the codes.
            synth_logit = D(codes)

            loss_g = -torch.mean(synth_logit)
            loss_g_tot.append(loss_g.item())

            loss_eg = loss_e + loss_g
            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss_eg.backward()
            total_loss_eg += loss_eg.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss_D: {loss_d.item():.4f} '
                      f'(GP: {loss_gp.item(): .4f}) '
                      f'Loss_EG: {loss_eg.item():.4f} '
                      f'(REC: {loss_e.item(): .4f}) '
                      f'Time: {datetime.now() - start_epoch_time}')

        # `i` is the number of batches processed in this epoch.
        log.debug(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Loss_D: {total_loss_d / i:.4f} '
            f'Loss_EG: {total_loss_eg / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        D.eval()
        with torch.no_grad():
            fake = G(fixed_noise).data.cpu().numpy()
            # `X` is the last batch of the epoch.
            X_rec = G(E(X)).data.cpu().numpy()
            X = X.data.cpu().numpy()

        plt.figure(figsize=(16, 9))
        plt.plot(loss_d_tot, 'r-', label="loss_d")
        plt.plot(loss_gp_tot, 'g-', label="loss_gp")
        plt.plot(loss_e_tot, 'b-', label="loss_e")
        plt.plot(loss_g_tot, 'k-', label="loss_g")
        plt.legend()
        plt.xlabel("Batch number")
        plt.ylabel("Loss value")
        plt.savefig(
            join(results_dir, 'samples', 'loss_plot.png'))
        plt.close()

        # Render at most five clouds, fewer when the batch is smaller.
        n_samples = min(5, X.shape[0], fake.shape[0])

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(n_samples):
            fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_fixed.png'))
            plt.close(fig)

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(D.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))

            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            torch.save(D_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_Do.pth'))
def main(config):
    """Evaluate a trained model and run the experiments enabled in ``config``.

    Loads the hypernetwork and the two encoders (visible / pocket) for the
    latest epoch found in the weights directory, computes the average
    reconstruction and KL losses over the dataset with gradients disabled,
    and then calls every experiment whose
    ``config['experiments'][<name>]['execute']`` flag is truthy.

    Args:
        config: Mapping with the run settings. Keys read here include
            ``seed``, ``arch``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``reconstruction_loss``, ``reconstruction_coef``,
            ``z_size`` and the per-experiment entries under ``experiments``.

    Raises:
        ValueError: If the dataset name or the reconstruction loss named in
            ``config`` is not supported.
        SystemExit: With status 1 when no saved epoch is found.
    """
    # `exit` is injected by the `site` module for interactive use and is not
    # guaranteed to exist; `sys.exit` is the programmatic equivalent.
    import sys

    set_seed(config['seed'])

    results_dir = prepare_results_dir(config,
                                      config['arch'],
                                      'experiments',
                                      dirs_to_create=[
                                          'interpolations', 'sphere',
                                          'points_interpolation',
                                          'different_number_points', 'fixed',
                                          'reconstruction', 'sphere_triangles',
                                          'sphere_triangles_interpolation'
                                      ])
    weights_path = get_weights_dir(config)
    epoch = find_latest_epoch(weights_path)

    if not epoch:
        print("Invalid 'weights_path' in configuration")
        sys.exit(1)

    setup_logging(results_dir)
    # The logger is published as a module-level global.
    # NOTE(review): presumably the experiment helpers log through it - confirm.
    global log
    log = logging.getLogger('aae')

    if not exists(join(results_dir, 'experiment_config.json')):
        with open(join(results_dir, 'experiment_config.json'), mode='w') as f:
            json.dump(config, f)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        # List the names that are actually accepted above.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=64,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True,
                                   pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    log.info("Weights for epoch: %s" % epoch)

    log.info("Loading weights...")
    hyper_network.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_G.pth')))
    encoder_pocket.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EP.pth')))
    encoder_visible.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EV.pth')))

    hyper_network.eval()
    encoder_visible.eval()
    encoder_pocket.eval()

    total_loss_eg = 0.0
    total_loss_e = 0.0
    total_loss_kld = 0.0
    # Collects the `non-visible` batches; they are the input of the
    # experiments below.
    x = []

    with torch.no_grad():
        for i, point_data in enumerate(points_dataloader, 1):
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # get visible point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            # The layout of all three tensors is decided from `X` alone.
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            x.append(X)
            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)
            # The hypernetwork receives both encodings concatenated along
            # dim 1 and returns one set of target-network weights per sample.
            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)
                target_network_input = generate_points(config=config,
                                                       epoch=epoch,
                                                       size=(X_whole.shape[2],
                                                             X_whole.shape[1]))
                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_e = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            # KL divergence between N(mu, exp(logvar)) and N(0, 1), summed
            # over the whole batch.
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_eg = loss_e + loss_kld
            total_loss_e += loss_e.item()
            total_loss_kld += loss_kld.item()
            total_loss_eg += loss_eg.item()

        # `i` is the number of batches that were evaluated.
        log.info(f'Loss_ALL: {total_loss_eg / i:.4f} '
                 f'Loss_R: {total_loss_e / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} ')

        # Batches may hold a different number of points: crop every batch to
        # the smallest point count so they can be concatenated.
        min_dim = min(x, key=lambda X: X.shape[2]).shape[2]
        x = [X[:, :, :min_dim] for X in x]
        x = torch.cat(x)

        if config['experiments']['interpolation']['execute']:
            interpolation(
                x, encoder_pocket, hyper_network, device, results_dir, epoch,
                config['experiments']['interpolation']['amount'],
                config['experiments']['interpolation']['transitions'])

        if config['experiments']['interpolation_between_two_points'][
                'execute']:
            interpolation_between_two_points(
                encoder_pocket, hyper_network, device, x, results_dir, epoch,
                config['experiments']['interpolation_between_two_points']
                ['amount'], config['experiments']
                ['interpolation_between_two_points']['image_points'],
                config['experiments']['interpolation_between_two_points']
                ['transitions'])

        if config['experiments']['reconstruction']['execute']:
            reconstruction(encoder_pocket, hyper_network, device, x,
                           results_dir, epoch,
                           config['experiments']['reconstruction']['amount'])

        if config['experiments']['sphere']['execute']:
            sphere(encoder_pocket, hyper_network, device, x, results_dir,
                   epoch, config['experiments']['sphere']['amount'],
                   config['experiments']['sphere']['image_points'],
                   config['experiments']['sphere']['start'],
                   config['experiments']['sphere']['end'],
                   config['experiments']['sphere']['transitions'])

        if config['experiments']['sphere_triangles']['execute']:
            sphere_triangles(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles']['amount'],
                config['experiments']['sphere_triangles']['method'],
                config['experiments']['sphere_triangles']['depth'],
                config['experiments']['sphere_triangles']['start'],
                config['experiments']['sphere_triangles']['end'],
                config['experiments']['sphere_triangles']['transitions'])

        if config['experiments']['sphere_triangles_interpolation']['execute']:
            sphere_triangles_interpolation(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles_interpolation']
                ['amount'], config['experiments']
                ['sphere_triangles_interpolation']['method'],
                config['experiments']['sphere_triangles_interpolation']
                ['depth'], config['experiments']
                ['sphere_triangles_interpolation']['coefficient'],
                config['experiments']['sphere_triangles_interpolation']
                ['transitions'])

        if config['experiments']['different_number_of_points']['execute']:
            different_number_of_points(
                encoder_pocket, hyper_network, x, device, results_dir, epoch,
                config['experiments']['different_number_of_points']['amount'],
                config['experiments']['different_number_of_points']
                ['image_points'])

        if config['experiments']['fixed']['execute']:
            # Take the visible part of one freshly drawn batch from the
            # loader (this should probably be supplied explicitly, for
            # example through the argument parser).

            points_dataloader = DataLoader(dataset,
                                           batch_size=10,
                                           shuffle=True,
                                           num_workers=8,
                                           drop_last=True,
                                           pin_memory=True,
                                           collate_fn=collate_fn)
            X_visible = next(iter(points_dataloader))['visible'].to(
                device, dtype=torch.float)
            # Same layout guard as in the evaluation loop: only transpose
            # when the data arrives as [BATCH, N_POINTS, N_DIM].
            if X_visible.size(-1) == 3:
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            # NOTE(review): (3, 2048) looks like the shape of the generated
            # point set - confirm against `fixed`.
            fixed(hyper_network, encoder_visible, X_visible, device,
                  results_dir, epoch, config['experiments']['fixed']['amount'],
                  config['z_size'] // 2,
                  config['experiments']['fixed']['mean'],
                  config['experiments']['fixed']['std'], (3, 2048),
                  config['experiments']['fixed']['triangulation']['execute'],
                  config['experiments']['fixed']['triangulation']['method'],
                  config['experiments']['fixed']['triangulation']['depth'])
예제 #3
0
def main(config):
    """Train the hypernetwork-based adversarial autoencoder from ``config``.

    An encoder (``aae.Encoder`` or, when ``config['pointnet']`` is truthy,
    ``PointNet``) maps each point cloud to a latent code. ``aae.HyperNetwork``
    turns the code into the weights of an ``aae.TargetNetwork``, which maps
    points produced by ``generate_points`` to the reconstructed cloud. A
    discriminator on the latent codes is trained either with the
    ``mean(D(codes)) - mean(D(noise))`` loss plus gradient penalty
    (``config['wasserstein']``, default ``True``) or with a least-squares
    loss.

    Training resumes from the latest saved epoch when one exists. Sample
    renders, weights and per-epoch loss totals are written under the results
    directory at the frequencies given in ``config``.

    Args:
        config: Mapping with the run settings. Keys read here include
            ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``pointnet``, ``reconstruction_loss``, ``reconstruction_coef``,
            ``optimizer`` (``E_HN`` and ``D``), ``target_network_input``,
            ``z_size``, ``normal_mu``, ``normal_std``, ``wasserstein``,
            ``gradient_penalty_coef``, ``feature_regularization_coef``,
            ``max_epochs``, ``save_samples_frequency``,
            ``save_weights_frequency`` and ``clean_weights_dir``.

    Raises:
        ValueError: If the dataset name or the reconstruction loss named in
            ``config`` is not supported.
        AssertionError: If point normalization is enabled with a type other
            than ``progressive``.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'aae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Write the configuration only once so a resumed run keeps the first copy.
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('aae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    # No drop_last here: the last batch may be smaller, so every per-batch
    # size below is taken from the tensors rather than from the config.
    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   pin_memory=True)

    pointnet = config.get('pointnet', False)
    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)

    if pointnet:
        from models.pointnet import PointNet
        encoder = PointNet(config).to(device)
        # PointNet initializes its own weights during instance creation
    else:
        encoder = aae.Encoder(config).to(device)
        encoder.apply(weights_init)

    discriminator = aae.Discriminator(config).to(device)

    hyper_network.apply(weights_init)
    discriminator.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        if pointnet:
            from utils.metrics import chamfer_distance
            reconstruction_loss = chamfer_distance
        else:
            from losses.champfer_loss import ChamferLoss
            reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from utils.metrics import earth_mover_distance
        reconstruction_loss = earth_mover_distance
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    # The optimizer classes are looked up by name in `torch.optim`.
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(chain(encoder.parameters(), hyper_network.parameters()),
                                    **config['optimizer']['E_HN']['hyperparams'])

    discriminator_optimizer = getattr(optim, config['optimizer']['D']['type'])
    discriminator_optimizer = discriminator_optimizer(discriminator.parameters(),
                                                      **config['optimizer']['D']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        # Resume: restore models, optimizers and the loss history of the
        # last finished epoch.
        log.info("Loading weights...")
        hyper_network.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_E.pth')))
        discriminator.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_D.pth')))

        e_hn_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))

        discriminator_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_Do.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path, f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_g = np.load(join(metrics_path, f'{starting_epoch - 1:05}_G.npy')).tolist()
        losses_eg = np.load(join(metrics_path, f'{starting_epoch - 1:05}_EG.npy')).tolist()
        losses_d = np.load(join(metrics_path, f'{starting_epoch - 1:05}_D.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_g = []
        losses_eg = []
        losses_d = []

    normalize_points = config['target_network_input']['normalization']['enable']
    if normalize_points:
        normalization_type = config['target_network_input']['normalization']['type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    # Cached target-network input; regenerated for every sample unless
    # config['target_network_input']['constant'] is truthy.
    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder.train()
        discriminator.train()

        total_loss_all = 0.0
        total_loss_reconstruction = 0.0
        total_loss_encoder = 0.0
        total_loss_discriminator = 0.0
        total_loss_regularization = 0.0
        for i, point_data in enumerate(points_dataloader, 1):

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            if pointnet:
                _, feature_transform, codes = encoder(X)
            else:
                codes, _, _ = encoder(X)

            # discriminator training
            noise = torch.empty(codes.shape[0], config['z_size']).normal_(mean=config['normal_mu'],
                                                                          std=config['normal_std']).to(device)
            synth_logit = discriminator(codes)
            real_logit = discriminator(noise)
            if config.get('wasserstein', True):
                loss_discriminator = torch.mean(synth_logit) - torch.mean(real_logit)

                alpha = torch.rand(codes.shape[0], 1).to(device)
                differences = codes - noise
                interpolates = noise + alpha * differences
                disc_interpolates = discriminator(interpolates)

                # gradient_penalty_function
                # Penalise the discriminator when the norm of its gradient
                # at the interpolated points differs from 1.
                gradients = grad(
                    outputs=disc_interpolates,
                    inputs=interpolates,
                    grad_outputs=torch.ones_like(disc_interpolates).to(device),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]
                slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
                gradient_penalty = ((slopes - 1) ** 2).mean()
                loss_gp = config['gradient_penalty_coef'] * gradient_penalty
                loss_discriminator += loss_gp
            else:
                # An alternative is a = -1, b = 1 iff c = 0
                a = 0.0
                b = 1.0
                # Reduce to a scalar: `.backward()` without arguments and
                # `.item()` both require a single-element tensor.
                loss_discriminator = torch.mean(
                    0.5 * ((real_logit - b)**2 + (synth_logit - a)**2))

            discriminator_optimizer.zero_grad()
            discriminator.zero_grad()

            # retain_graph=True: the encoder part of the graph that produced
            # `codes` is reused by the encoder/hypernetwork step below.
            loss_discriminator.backward(retain_graph=True)
            total_loss_discriminator += loss_discriminator.item()
            discriminator_optimizer.step()

            # hyper network training
            target_networks_weights = hyper_network(codes)

            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)

                if not config['target_network_input']['constant'] or target_network_input is None:
                    target_network_input = generate_points(config=config, epoch=epoch, size=(X.shape[2], X.shape[1]))

                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            if pointnet:
                loss_reconstruction = config['reconstruction_coef'] * \
                                      reconstruction_loss(torch.transpose(X, 1, 2).contiguous(),
                                                          torch.transpose(X_rec, 1, 2).contiguous(),
                                                          batch_size=X.shape[0]).mean()
            else:
                loss_reconstruction = torch.mean(
                    config['reconstruction_coef'] *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

            # encoder training
            # Re-evaluate the freshly updated discriminator on the codes.
            synth_logit = discriminator(codes)
            if config.get('wasserstein', True):
                loss_encoder = -torch.mean(synth_logit)
            else:
                # An alternative is c = 0 iff a = -1, b = 1
                c = 1.0
                # Scalar reduction, for the same reason as the
                # discriminator loss above.
                loss_encoder = torch.mean(0.5 * (synth_logit - c)**2)

            if pointnet:
                regularization_loss = config['feature_regularization_coef'] * \
                                      feature_transform_regularization(feature_transform).mean()
                loss_all = loss_reconstruction + loss_encoder + regularization_loss
            else:
                loss_all = loss_reconstruction + loss_encoder

            e_hn_optimizer.zero_grad()
            encoder.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_reconstruction += loss_reconstruction.item()
            total_loss_encoder += loss_encoder.item()
            total_loss_all += loss_all.item()

            if pointnet:
                total_loss_regularization += regularization_loss.item()

        # `i` is the number of batches processed in this epoch.
        log.info(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Total_Loss: {total_loss_all / i:.4f} '
            f'Loss_R: {total_loss_reconstruction / i:.4f} '
            f'Loss_E: {total_loss_encoder / i:.4f} '
            f'Loss_D: {total_loss_discriminator / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        if pointnet:
            log.info(f'Loss_Regularization: {total_loss_regularization / i:.4f}')

        # The history stores epoch totals, not per-batch averages.
        losses_e.append(total_loss_reconstruction)
        losses_g.append(total_loss_encoder)
        losses_eg.append(total_loss_all)
        losses_d.append(total_loss_discriminator)

        #
        # Save intermediate results
        #
        if epoch % config['save_samples_frequency'] == 0:
            log.debug('Saving samples...')

            # `X` / `X_rec` are the last batch of the epoch.
            X = X.cpu().numpy()
            X_rec = X_rec.detach().cpu().numpy()

            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2], in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2], in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

        if epoch % config['save_weights_frequency'] == 0:
            # Remove older checkpoints only when a new one is about to be
            # written. Cleaning on every epoch would delete the latest
            # checkpoint on epochs that do not save, leaving nothing to
            # resume from.
            if config['clean_weights_dir']:
                log.debug('Cleaning weights path: %s' % weights_path)
                shutil.rmtree(weights_path, ignore_errors=True)
                os.makedirs(weights_path, exist_ok=True)

            log.debug('Saving weights and losses...')

            torch.save(hyper_network.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(e_hn_optimizer.state_dict(), join(weights_path, f'{epoch:05}_EGo.pth'))
            torch.save(discriminator.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(discriminator_optimizer.state_dict(), join(weights_path, f'{epoch:05}_Do.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_G'), np.array(losses_g))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
            np.save(join(metrics_path, f'{epoch:05}_D'), np.array(losses_d))
예제 #4
0
def main(config):
    """Train the pocket/visible encoders and the hyper-network as a VAE.

    The function seeds the run, prepares the results directory, builds the
    dataset and models selected by ``config``, optionally resumes from the
    latest epoch reported by ``find_latest_epoch`` and then runs the training
    loop up to ``config['max_epochs']`` (inclusive). Sample plots, weights,
    optimizer state and loss histories are written every
    ``config['save_frequency']`` epochs.

    Args:
        config: Dictionary with the experiment settings. Keys read here
            include ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``reconstruction_loss``, ``reconstruction_coef``, ``optimizer``,
            ``target_network_input``, ``max_epochs``, ``save_frequency`` and
            ``clean_weights_dir``.

    Raises:
        ValueError: If the dataset name or the reconstruction loss name is
            not one of the supported values.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'vae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Store the configuration once, next to the results it produced.
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        # The message lists exactly the names handled by the branches above.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True,
                                   pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)

    hyper_network.apply(weights_init)
    encoder_pocket.apply(weights_init)
    encoder_visible.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    # A single optimizer updates both encoders and the hyper-network.
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(
        chain(encoder_visible.parameters(), encoder_pocket.parameters(),
              hyper_network.parameters()),
        **config['optimizer']['E_HN']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        # Resume: restore models, optimizer state and loss histories saved
        # for the previous epoch.
        log.info("Loading weights...")
        hyper_network.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder_pocket.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EP.pth')))
        encoder_visible.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EV.pth')))

        e_hn_optimizer.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path,
                                f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_kld = np.load(
            join(metrics_path, f'{starting_epoch - 1:05}_KLD.npy')).tolist()
        losses_eg = np.load(
            join(metrics_path, f'{starting_epoch - 1:05}_EG.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_kld = []
        losses_eg = []

    if config['target_network_input']['normalization']['enable']:
        normalization_type = config['target_network_input']['normalization'][
            'type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    # Cached across batches/epochs when config asks for a constant input.
    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder_visible.train()
        encoder_pocket.train()

        total_loss_all = 0.0
        total_loss_r = 0.0
        total_loss_kld = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            # 'non-visible' part of the point cloud (pocket encoder input)
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # 'visible' part of the point cloud (visible encoder input)
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # whole point cloud (reconstruction target)
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            # Allocate the output buffer directly on the training device
            # instead of creating it on the host and copying it over.
            X_rec = torch.zeros(X_whole.shape, device=device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                # One target network per sample, built from generated weights.
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                if not config['target_network_input'][
                        'constant'] or target_network_input is None:
                    target_network_input = generate_points(
                        config=config,
                        epoch=epoch,
                        size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_r = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            # KL divergence between N(mu, exp(logvar)) and N(0, 1), summed
            # over all elements.
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_all = loss_r + loss_kld
            e_hn_optimizer.zero_grad()
            encoder_visible.zero_grad()
            encoder_pocket.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_r += loss_r.item()
            total_loss_kld += loss_kld.item()
            total_loss_all += loss_all.item()

        log.info(f'[{epoch}/{config["max_epochs"]}] '
                 f'Loss_ALL: {total_loss_all / i:.4f} '
                 f'Loss_R: {total_loss_r / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} '
                 f'Time: {datetime.now() - start_epoch_time}')

        # Histories store the per-epoch sums (not the per-batch means).
        losses_e.append(total_loss_r)
        losses_kld.append(total_loss_kld)
        losses_eg.append(total_loss_all)

        #
        # Save intermediate results
        #
        if epoch % config['save_frequency'] == 0:
            # Host copies are needed only for plotting, so they are made on
            # save epochs only. The last batch of the epoch is plotted.
            X_np = X.cpu().numpy()
            X_whole_np = X_whole.cpu().numpy()
            X_rec_np = X_rec.detach().cpu().numpy()

            for k in range(min(5, X_rec_np.shape[0])):
                fig = plot_3d_point_cloud(X_rec_np[k][0],
                                          X_rec_np[k][1],
                                          X_rec_np[k][2],
                                          in_u_sphere=True,
                                          show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples',
                         f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X_whole_np[k][0],
                                          X_whole_np[k][1],
                                          X_whole_np[k][2],
                                          in_u_sphere=True,
                                          show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

                # NOTE(review): X holds the 'non-visible' part but is saved
                # as *_visible.png - confirm which part is intended here.
                fig = plot_3d_point_cloud(X_np[k][0],
                                          X_np[k][1],
                                          X_np[k][2],
                                          in_u_sphere=True,
                                          show=False)
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_visible.png'))
                plt.close(fig)

        if config['clean_weights_dir']:
            # Runs every epoch: drops all previously saved weights so that
            # only the checkpoint written below (if any) remains.
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        if epoch % config['save_frequency'] == 0:
            log.debug('Saving data...')

            torch.save(hyper_network.state_dict(),
                       join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder_visible.state_dict(),
                       join(weights_path, f'{epoch:05}_EV.pth'))
            torch.save(encoder_pocket.state_dict(),
                       join(weights_path, f'{epoch:05}_EP.pth'))
            torch.save(e_hn_optimizer.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_KLD'),
                    np.array(losses_kld))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
예제 #5
0
def dist_chamfer(x, y):
    """Return the minima of the pairwise distances between `x` and `y`.

    The pairwise distance tensor comes from
    ``ChamferLoss.batch_pairwise_dist(x, y)``; the result is a pair holding
    its minimum values along dim 1 and along dim 2, in that order.
    """
    from losses.champfer_loss import ChamferLoss
    from utils.util import cuda_setup

    criterion = ChamferLoss()
    criterion = criterion.to(cuda_setup())
    pairwise = criterion.batch_pairwise_dist(x, y)

    # `min(dim)` yields (values, indices); only the values are returned.
    min_along_dim1, _ = pairwise.min(1)
    min_along_dim2, _ = pairwise.min(2)
    return min_along_dim1, min_along_dim2
예제 #6
0
def main(config):
    """Train an encoder/generator pair as a point-cloud auto-encoder.

    Seeds the run, prepares the results directory, builds the dataset and the
    ``Encoder``/``Generator`` from the architecture module named in
    ``config['arch']``, optionally resumes from the latest saved epoch, and
    trains up to ``config['max_epochs']`` (inclusive). After every epoch the
    last batch and its reconstruction are plotted; weights and optimizer
    state are saved every ``config['save_frequency']`` epochs.

    Args:
        config: Dictionary with the experiment settings. Keys read here
            include ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``arch``, ``reconstruction_loss``, ``reconstruction_coef``,
            ``optimizer``, ``max_epochs`` and ``save_frequency``.

    Raises:
        ValueError: If the dataset name or the reconstruction loss name is
            not supported.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Store the configuration once, next to the results it produced.
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        # Only `shapenet` is handled above, so only that name is advertised.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')
    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True,
                                   pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"model.architectures.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    # A single optimizer updates both the encoder and the generator.
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        # Resume from the checkpoint of the previous epoch.
        G.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_E.pth')))

        EG_optim.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()

        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(X))

            loss = torch.mean(config['reconstruction_coef'] *
                              reconstruction_loss(
                                  X.permute(0, 2, 1) + 0.5,
                                  X_rec.permute(0, 2, 1) + 0.5))

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss.backward()
            total_loss += loss.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss: {loss.item():.4f} '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(f'[{epoch}/{config["max_epochs"]}] '
                  f'Loss: {total_loss / i:.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        with torch.no_grad():
            X_rec = G(E(X)).data.cpu().numpy()

        # Plot from host numpy arrays: the last batch still lives on the
        # training device, and the reconstruction is already a numpy array.
        X_numpy = X.cpu().numpy()

        # Never index past the end of the batch when it has < 5 samples.
        n_samples = min(5, X_numpy.shape[0])

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_numpy[k][0],
                                      X_numpy[k][1],
                                      X_numpy[k][2],
                                      in_u_sphere=True,
                                      show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_rec[k][0],
                                      X_rec[k][1],
                                      X_rec[k][2],
                                      in_u_sphere=True,
                                      show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples',
                     f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))

            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))
예제 #7
0
def main(config):
    """Train an encoder/generator pair that maps features to point clouds.

    Reads the module-level settings ``results_dir``, ``experiment``,
    ``max_epochs`` and ``batch_size``; ``config`` itself is not read. Each
    batch of ``VesselDataset3`` yields a pair ``(F, X)``: ``F`` is fed to the
    encoder and the generator output is compared with ``X`` using the
    Chamfer loss. After every epoch the value ranges of the last batch and
    of its reconstruction are printed.

    Args:
        config: Unused; kept for interface compatibility.
    """
    global results_dir, max_epochs, batch_size

    # seeds
    random.seed(2019)
    torch.manual_seed(2019)
    torch.cuda.manual_seed_all(2019)

    # setting directory to save results and weights
    results_dir = join(results_dir, experiment)
    results_dir = prepare_results_dir(results_dir, b_clean=False)
    weights_path = join(results_dir, 'weights')

    device = cuda_setup(True, 0)

    log = logging.getLogger(__name__)

    dataset = VesselDataset3()

    points_dataloader = DataLoader(dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True,
                                   pin_memory=True)

    G = Generator().to(device)
    E = Encoder().to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    reconstruction_loss = ChamferLoss().to(device)

    # A single optimizer updates both the encoder and the generator.
    EG_optim = torch.optim.Adam(chain(E.parameters(), G.parameters()),
                                lr=0.0005,
                                weight_decay=0,
                                betas=[0.9, 0.999],
                                amsgrad=False)

    for epoch in range(0, max_epochs):
        start_epoch_time = datetime.now()

        G.train()
        E.train()

        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 0):
            F, X = point_data
            X = X.to(device)
            F = F.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(F))

            loss = torch.mean(0.05 * reconstruction_loss(
                X.permute(0, 2, 1) + 0.5,
                X_rec.permute(0, 2, 1) + 0.5))

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss.backward()
            total_loss += loss.item()

            EG_optim.step()

            print(f'[{epoch}: ({i})] '
                  f'Loss: {loss.item():.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        # `i` is 0-based here, so the number of processed batches is i + 1.
        # Dividing by `i` would bias the mean and fail for a single batch.
        num_batches = i + 1
        log.debug(
            f'[{epoch}/{max_epochs}] '
            f'Loss: {total_loss / num_batches:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        G.eval()
        E.eval()

        # Reconstruct the last batch of the epoch in evaluation mode.
        with torch.no_grad():
            X_rec = G(E(F)).data.cpu().numpy()

        X_cpu = X.cpu().numpy()
        print(X_rec.min(axis=0), X_rec.max(axis=0))
        print(X_cpu.min(axis=0), X_cpu.max(axis=0))
def main(config):
    """Train an encoder/generator pair as a point-cloud auto-encoder.

    Reads the module-level settings ``results_dir``, ``experiment``,
    ``max_epochs``, ``batch_size`` and ``save_frequency``; ``config`` itself
    is not read. Training resumes from the epoch after the one reported by
    ``find_latest_epoch``. After every epoch the last batch and its
    reconstruction are passed to ``save_point_cloud``; weights are saved
    every ``save_frequency`` epochs.

    Args:
        config: Unused; kept for interface compatibility.
    """
    global results_dir, max_epochs, batch_size

    # setting seeds
    random.seed(2019)
    torch.manual_seed(2019)
    torch.cuda.manual_seed_all(2019)

    # setting directory to save results and weights
    results_dir = join(results_dir, experiment)
    results_dir = prepare_results_dir(results_dir, b_clean=False)
    weights_path = join(results_dir, 'weights')

    # finding the last saved epoch, if any, in the results directory
    starting_epoch = find_latest_epoch(results_dir) + 1

    # setting device for pytorch usage
    device = cuda_setup(True, 0)

    # used to log useful information
    log = logging.getLogger(__name__)

    # load the dataset
    dataset = VesselDataset2(
        root_dir=
        "/home/texs/Documents/Repositorios/point_cloud_reconstruction/data/ModelNet10"
    )
    points_dataloader = DataLoader(dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True,
                                   pin_memory=True)

    # creating and initialising the models
    G = Generator().to(device)
    E = Encoder().to(device)
    G.apply(weights_init)
    E.apply(weights_init)

    # setting reconstruction loss
    reconstruction_loss = ChamferLoss().to(device)

    # a single optimizer updates both the encoder and the generator
    EG_optim = torch.optim.Adam(chain(E.parameters(), G.parameters()),
                                lr=0.0005,
                                weight_decay=0,
                                betas=[0.9, 0.999],
                                amsgrad=False)

    # loading weights if they exist in the results directory
    load_latest_epoch(E, G, EG_optim, weights_path, starting_epoch)

    # training
    for epoch in range(starting_epoch, max_epochs):
        start_epoch_time = datetime.now()

        G.train()
        E.train()

        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 0):
            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(X))

            loss = torch.mean(0.05 * reconstruction_loss(
                X.permute(0, 2, 1) + 0.5,
                X_rec.permute(0, 2, 1) + 0.5))

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss.backward()
            total_loss += loss.item()

            EG_optim.step()

            print(f'[{epoch}: ({i})] '
                  f'Loss: {loss.item():.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        # `i` is 0-based here, so the number of processed batches is i + 1.
        # Dividing by `i` would bias the mean and fail for a single batch.
        num_batches = i + 1
        log.debug(f'[{epoch}/{max_epochs}] '
                  f'Loss: {total_loss / num_batches:.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        G.eval()
        E.eval()

        # Reconstruct the last batch of the epoch in evaluation mode.
        with torch.no_grad():
            X_rec = G(E(X)).data.cpu().numpy()

        X_cpu = X.cpu().numpy()

        save_point_cloud(X_cpu, X_rec, epoch, n_fig=5, results_dir=results_dir)

        if epoch % save_frequency == 0:
            save_weights(E, G, EG_optim, weights_path, epoch)
예제 #9
0
def main(config):
    """Train an encoder/generator pair with a reconstruction + CW penalty.

    Seeds the run, prepares the results directory, builds the dataset and the
    ``Encoder``/``Generator`` from ``models.<config['arch']>``, optionally
    resumes from the latest saved epoch, and trains up to
    ``config['max_epochs']`` (inclusive). The total loss is the mean
    reconstruction loss plus ``config['reconstruction_coef'] *
    cw_distance(mu)``. After every epoch, plots of the last batch, of its
    reconstruction and of samples generated from a fixed noise tensor are
    written; the loss history, weights and optimizer state are saved every
    ``config['save_frequency']`` epochs.

    Args:
        config: Dictionary with the experiment settings. Keys read here
            include ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``arch``, ``reconstruction_loss``, ``reconstruction_coef``,
            ``z_size``, ``optimizer``, ``max_epochs`` and ``save_frequency``.

    Raises:
        ValueError: If the dataset name or the reconstruction loss name is
            not supported.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Store the configuration once, next to the results it produced.
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        # Only `shapenet` is handled above, so only that name is advertised.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    elif config['reconstruction_loss'].lower() == 'cramer_wold':
        from losses.cramer_wold import CWSample
        reconstruction_loss = CWSample().to(device)
    else:
        # The message lists exactly the names handled by the branches above.
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer`, '
                         f'`earth_mover` or `cramer_wold`, '
                         f'got: {config["reconstruction_loss"]}')
    #
    # Float Tensors
    #
    # Noise drawn once so generated samples are comparable across epochs.
    fixed_noise = torch.FloatTensor(config['batch_size'], config['z_size'], 1)
    fixed_noise.normal_(mean=0, std=0.2)
    std_assumed = torch.tensor(0.2)

    fixed_noise = fixed_noise.to(device)
    std_assumed = std_assumed.to(device)

    #
    # Optimizers
    #
    # A single optimizer updates both the encoder and the generator.
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        # Resume from the checkpoint of the previous epoch.
        G.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_E.pth')))

        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    # Per-epoch means: [reconstruction, CW penalty, total].
    losses = []

    with trange(starting_epoch, config['max_epochs'] + 1) as t:
        for epoch in t:
            start_epoch_time = datetime.now()

            G.train()
            E.train()

            total_loss = 0.0
            losses_eg = []
            losses_e = []
            losses_kld = []

            for i, point_data in enumerate(points_dataloader, 1):
                X, _ = point_data
                X = X.to(device)

                # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                if X.size(-1) == 3:
                    X.transpose_(X.dim() - 2, X.dim() - 1)

                codes, mu, logvar = E(X)
                X_rec = G(codes)

                loss_e = torch.mean(
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

                # Despite the name, this term is cw_distance(mu) scaled by
                # the reconstruction coefficient, not a KL divergence.
                loss_kld = config['reconstruction_coef'] * cw_distance(mu)

                loss_eg = loss_e + loss_kld
                EG_optim.zero_grad()
                E.zero_grad()
                G.zero_grad()

                loss_eg.backward()
                total_loss += loss_eg.item()
                EG_optim.step()

                losses_e.append(loss_e.item())
                losses_kld.append(loss_kld.item())
                losses_eg.append(loss_eg.item())

                t.set_description(
                    f'[{epoch}: ({i})] '
                    f'Loss_EG: {loss_eg.item():.4f} '
                    f'(REC: {loss_e.item(): .4f}'
                    f' KLD: {loss_kld.item(): .4f})'
                    f' Time: {datetime.now() - start_epoch_time}'
                )

            t.set_description(
                f'[{epoch}/{config["max_epochs"]}] '
                f'Loss_G: {total_loss / i:.4f} '
                f'Time: {datetime.now() - start_epoch_time}'
            )

            losses.append([
                np.mean(losses_e),
                np.mean(losses_kld),
                np.mean(losses_eg)
            ])

            #
            # Save intermediate results
            #
            G.eval()
            E.eval()
            with torch.no_grad():
                fake = G(fixed_noise).data.cpu().numpy()
                codes, _, _ = E(X)
                X_rec = G(codes).data.cpu().numpy()

            X_numpy = X.cpu().numpy()

            # Never index past the end of the batch when it has < 5 samples.
            n_samples = min(5, X_numpy.shape[0])

            for k in range(n_samples):
                fig = plot_3d_point_cloud(X_numpy[k][0], X_numpy[k][1], X_numpy[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

            for k in range(n_samples):
                fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', 'fixed', f'{epoch:05}_{k}_fixed.png'))
                plt.close(fig)

            for k in range(n_samples):
                fig = plot_3d_point_cloud(X_rec[k][0],
                                          X_rec[k][1],
                                          X_rec[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

            if epoch % config['save_frequency'] == 0:
                df = pd.DataFrame(losses, columns=['loss_e', 'loss_kld', 'loss_eg'])
                df.to_json(join(results_dir, 'losses', 'losses.json'))
                fig = df.plot.line().get_figure()
                # NOTE(review): `k` is the index left over from the sample
                # loops above; kept so existing file names do not change.
                fig.savefig(join(results_dir, 'losses', f'{epoch:05}_{k}.png'))
                # Close the figure so loss plots do not pile up in memory.
                plt.close(fig)

                torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
                torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))

                torch.save(EG_optim.state_dict(),
                           join(weights_path, f'{epoch:05}_EGo.pth'))
예제 #10
0
def main(config: dict):
    """Run the model in ``training`` or ``experiments`` mode.

    The function seeds the RNGs, prepares the results directory, builds the
    model / optimizer / LR scheduler from ``config``, restores the latest
    saved state if one exists, and then either trains the model or executes
    the experiments enabled in ``config['experiments']['settings']``.

    Args:
        config: Nested configuration dictionary. Keys read here:
            ``mode``, ``setup`` (``seed``, ``gpu_id``), ``telegram_logger``
            (``enable``), ``full_model``, ``dataset``, ``training``
            (``optimizer``, ``lr_scheduler``, ``dataloader``, ``max_epoch``,
            ``loss_coef``, ``state_save_frequency``, optional
            ``min_save_epoch``) and ``experiments`` (``epoch``,
            ``settings``). In experiments mode the ``execute`` flag is
            popped from every experiment's settings, i.e. ``config`` is
            mutated.

    Raises:
        ValueError: If ``config['mode']`` is neither ``training`` nor
            ``experiments``.
        FileNotFoundError: If experiments mode is requested but no saved
            epoch is found.
        SystemExit: Always raised with code 0 once the run finishes.
    """
    # region Setup
    seed_setup(config['setup']['seed'])

    run_mode: str = config['mode']
    result_dir_path: str = get_results_dir_path(config, run_mode)

    if run_mode == 'training':
        dirs_to_create = ('weights', 'samples', 'metrics')
        weights_path = join(result_dir_path, 'weights')
        metrics_path = join(result_dir_path, 'metrics')
    elif run_mode == 'experiments':
        # Experiments write into their own result dir (one sub-directory per
        # experiment) but read weights/metrics from the training result dir.
        dirs_to_create = tuple(experiment_functions_dict.keys())
        weights_path = join(get_results_dir_path(config, 'training'),
                            'weights')
        metrics_path = join(get_results_dir_path(config, 'training'),
                            'metrics')
    else:
        raise ValueError("mode should be `training` or `experiments`")

    results_dir_setup(result_dir_path, dirs_to_create)

    with open(join(result_dir_path, 'last_config.json'), mode='w') as f:
        json.dump(config, f)

    logging_setup(result_dir_path)
    log = logging.getLogger()

    log.info(f'Current mode {run_mode}')

    # `tg_log` only exists when the Telegram logger is enabled; every later
    # use is guarded by the same config flag.
    if config['telegram_logger']['enable']:
        tg_log = TelegramLogger.get_logger(config['telegram_logger'])

    device = cuda_setup(config['setup']['gpu_id'])
    log.info(f'Device variable: {device}')

    reconstruction_loss = ChamferLoss().to(device)
    full_model = FullModel(config['full_model']).to(device)
    full_model.apply(weights_init)

    # Optimizer and scheduler classes are looked up by name from the config.
    optimizer_cls = getattr(optim, config['training']['optimizer']['type'])
    optimizer = optimizer_cls(
        full_model.parameters(),
        **config['training']['optimizer']['hyperparams'])

    scheduler_cls = getattr(optim.lr_scheduler,
                            config['training']['lr_scheduler']['type'])
    scheduler = scheduler_cls(
        optimizer, **config['training']['lr_scheduler']['hyperparams'])
    log.info(f'Model {get_model_name(config)} created')

    latest_epoch = find_latest_epoch(result_dir_path if run_mode ==
                                     "training" else weights_path)

    log.info(f'Latest epoch found: {latest_epoch}')

    if latest_epoch > 0:
        if run_mode == "training":
            latest_epoch = restore_model_state(weights_path, metrics_path,
                                               config['setup']['gpu_id'],
                                               latest_epoch, "latest",
                                               full_model, optimizer,
                                               scheduler)
        elif run_mode == "experiments":
            latest_epoch = restore_model_state(weights_path, metrics_path,
                                               config['setup']['gpu_id'],
                                               latest_epoch,
                                               config['experiments']['epoch'],
                                               full_model)
        log.info(f'Restored epoch : {latest_epoch}')
    elif run_mode == "experiments":
        # A single formatted message: passing two positional arguments to an
        # OSError subclass makes them (errno, strerror) and garbles the text.
        raise FileNotFoundError(f"no weights found at {weights_path}")
    # endregion Setup

    train_dataset, val_dataset_dict, test_dataset_dict = get_datasets(
        config['dataset'])

    log.info(f'Dataset loaded for classes: {list(val_dataset_dict)}')

    if run_mode == 'training':
        samples_path = join(result_dir_path, 'samples')
        train_dataloader = DataLoader(
            train_dataset,
            pin_memory=True,
            **config['training']['dataloader']['train'])
        val_dataloaders_dict = {
            cat_name: DataLoader(cat_ds,
                                 pin_memory=True,
                                 **config['training']['dataloader']['val'])
            for cat_name, cat_ds in val_dataset_dict.items()
        }
        if latest_epoch == 0:
            # `np.Infinity` was removed in NumPy 2.0; `np.inf` works on all
            # NumPy versions.
            best_epoch_loss = np.inf
            train_losses = []
            val_losses = []
        else:
            train_losses, val_losses, best_epoch_loss = restore_metrics(
                metrics_path, latest_epoch)

        max_epoch = config['training']['max_epoch']
        telegram_enabled = config['telegram_logger']['enable']

        for epoch in range(latest_epoch + 1, max_epoch + 1):
            start_epoch_time = datetime.now()
            log.debug("Epoch: %s" % epoch)

            (full_model, optimizer, epoch_loss_all, epoch_loss_kld,
             epoch_loss_r, latest_existing, latest_gt,
             latest_rec) = train_epoch(epoch, full_model, optimizer,
                                       train_dataloader, device,
                                       reconstruction_loss,
                                       config['training']['loss_coef'])
            scheduler.step()

            train_losses.append(
                np.array([epoch_loss_all, epoch_loss_r, epoch_loss_kld]))

            log_string = f'[{epoch}/{config["training"]["max_epoch"]}] ' \
                         f'Loss_ALL: {epoch_loss_all:.4f} ' \
                         f'Loss_R: {epoch_loss_r:.4f} ' \
                         f'Loss_E: {epoch_loss_kld:.4f} ' \
                         f'Time: {datetime.now() - start_epoch_time}'
            log.info(log_string)

            # Save plots for up to five samples of the last training batch.
            train_plots = []
            for k in range(min(5, latest_rec.shape[0])):
                train_plots.append(
                    save_plot(latest_existing[k], epoch, k, samples_path,
                              'existing'))
                train_plots.append(
                    save_plot(latest_rec[k], epoch, k, samples_path,
                              'reconstructed'))
                train_plots.append(
                    save_plot(latest_gt[k].T, epoch, k, samples_path, 'gt'))

            if telegram_enabled:
                tg_log.log_images(train_plots[:9], log_string)

            epoch_val_losses, epoch_val_samples = val_epoch(
                epoch, full_model, device, val_dataloaders_dict,
                val_dataset_dict.keys(), reconstruction_loss,
                config['training']['loss_coef'])

            is_new_best = epoch_val_losses['total'][0] < best_epoch_loss

            if is_new_best:
                best_epoch_loss = epoch_val_losses['total'][0]

            val_losses.append(epoch_val_losses['total'])

            log_string = f'val results[{config["training"]["loss_coef"]}*our_cd]:\n'
            for k, v in epoch_val_losses.items():
                log_string += k + ': ' + str(v) + '\n'

            if is_new_best:
                log_string += "new best epoch"

            log.info(log_string)

            # Three consecutive plots (existing, rec, gt) per category.
            val_plots = []
            for cat_name, sample in epoch_val_samples.items():
                val_plots.append(
                    save_plot(sample[0], epoch, cat_name, samples_path,
                              'val_existing'))
                val_plots.append(
                    save_plot(sample[2], epoch, cat_name, samples_path,
                              'val_rec'))
                val_plots.append(
                    save_plot(sample[1].T, epoch, cat_name, samples_path,
                              'val_gt'))

            if telegram_enabled:
                # Send the plot triples of at most three random categories.
                # `np.int` was removed in NumPy 1.24, so the builtin `int`
                # and integer division are used instead.
                n_triples = len(val_plots) // 3
                chosen_plot_idx = np.random.choice(
                    np.arange(n_triples, dtype=int),
                    min(3, n_triples),
                    replace=False)
                plots_to_log = []
                for idx in chosen_plot_idx:
                    plots_to_log.extend(val_plots[3 * idx:3 * idx + 3])
                tg_log.log_images(plots_to_log, log_string)

            # Checkpoint periodically or on a new best validation loss, but
            # never before the optional `min_save_epoch` threshold.
            save_due = (epoch % config['training']['state_save_frequency'] == 0
                        or is_new_best)
            if save_due and epoch > config['training'].get('min_save_epoch', 0):
                torch.save(full_model.state_dict(),
                           join(weights_path, f'{epoch:05}_model.pth'))
                torch.save(optimizer.state_dict(),
                           join(weights_path, f'{epoch:05}_O.pth'))
                torch.save(scheduler.state_dict(),
                           join(weights_path, f'{epoch:05}_S.pth'))

                np.save(join(metrics_path, f'{epoch:05}_train'),
                        np.array(train_losses))
                np.save(join(metrics_path, f'{epoch:05}_val'),
                        np.array(val_losses))

                log_string = "Epoch: %s saved" % epoch
                log.debug(log_string)
                if telegram_enabled:
                    tg_log.log(log_string)

    elif run_mode == 'experiments':
        full_model.eval()

        with torch.no_grad():
            for experiment_name, experiment_dict in config['experiments'][
                    'settings'].items():
                # `execute` is a switch, not an experiment argument, so it is
                # removed before the remaining settings are forwarded.
                if experiment_dict.pop('execute', False):
                    log.info(experiment_name)
                    experiment_functions_dict[experiment_name](
                        full_model, device, test_dataset_dict, result_dir_path,
                        latest_epoch, **experiment_dict)

    # The builtin `exit` is injected by the `site` module for interactive
    # use and may be missing (e.g. `python -S`); raise SystemExit directly.
    raise SystemExit(0)
예제 #11
0
def evaluate_generativity(full_model: FullModel,
                          device,
                          datasets_dict,
                          results_dir,
                          epoch,
                          batch_size,
                          num_workers,
                          mean=0.0,
                          std=0.005):
    """Score sampled reconstructions against per-category ground truth.

    For every category in ``datasets_dict`` the second element of each
    batch (named ``missing`` by the loader tuple) is gathered into one
    tensor. Then, for each batch, as many reconstructions as there are
    gathered ground-truth entries are generated from the first batch
    element (``existing``), each with freshly drawn Gaussian noise. The
    values returned by ``compute_all_metrics`` and
    ``jsd_between_point_cloud_sets`` are summed over all batches of the
    category (they are not averaged here).

    Args:
        full_model: Model to sample from; must provide ``get_noise_size()``.
        device: Torch device on which tensors and the Chamfer loss live.
        datasets_dict: Mapping of category name to dataset; each item is
            unpacked as a 4-tuple.
        results_dir: Directory that contains an ``evaluate_generativity``
            sub-directory for the JSON report.
        epoch: Forwarded to the model and used as the report file prefix.
        batch_size: Forwarded to ``compute_all_metrics`` (the data loaders
            themselves always use a batch size of 1).
        num_workers: Worker count for the data loaders.
        mean: Mean of the sampled noise.
        std: Standard deviation of the sampled noise.
    """
    loaders = {}
    for category, category_ds in datasets_dict.items():
        loaders[category] = DataLoader(category_ds,
                                       pin_memory=True,
                                       batch_size=1,
                                       num_workers=num_workers)

    chamfer_loss = ChamferLoss().to(device)
    results = {}

    with torch.no_grad():
        for category, loader in loaders.items():
            # First pass: gather the ground-truth parts of the category.
            gt_parts = [missing.to(device) for _, missing, _, _ in loader]
            cat_gt = torch.cat(gt_parts).contiguous()

            totals = {}

            # Second pass: sample reconstructions for every input.
            for existing, _, _, _ in tqdm(loader, total=len(loader)):
                existing = existing.to(device)

                samples = []
                for _ in range(len(cat_gt)):
                    noise = torch.zeros(1, full_model.get_noise_size())
                    noise = noise.normal_(mean=mean, std=std).to(device)
                    output = full_model(existing,
                                        None, [1, 2048, 3],
                                        epoch,
                                        device,
                                        noise=noise)

                    cloud = output.cpu().detach().numpy()[0]
                    # Keep the 1024 points with the smallest values in row 1
                    # (assumes `cloud` is (3, N), i.e. row 1 is the y axis --
                    # TODO confirm against the model output layout).
                    kept = cloud[1].argsort()[:1024]
                    points = torch.from_numpy(cloud.T[kept])
                    samples.append(points.unsqueeze(0).to(device))

                obj_recs = torch.cat(samples)

                metrics = compute_all_metrics(obj_recs, cat_gt, batch_size,
                                              chamfer_loss)
                for metric_name, metric_value in metrics.items():
                    running = totals.get(metric_name, 0.0)
                    totals[metric_name] = running + metric_value.item()

                jsd_so_far = totals.get('jsd', 0.0)
                totals['jsd'] = jsd_so_far + jsd_between_point_cloud_sets(
                    obj_recs.cpu().detach().numpy(),
                    cat_gt.cpu().numpy())

            results[category] = totals
            print(category, totals)

    report_name = str(epoch) + 'eval_gen_by_cat.json'
    report_path = join(results_dir, 'evaluate_generativity', report_name)
    with open(report_path, mode='w') as f:
        json.dump(results, f)