Exemplo n.º 1
0
    def test_solver_gen_step_poster(self, solver):
        """Check that one posterior-phase generator step lowers the loss.

        Runs ``solver.gen_step(grid_type='poster')`` and then re-evaluates
        ``solver.train_data`` with ``solver.gen_forward`` after re-seeding
        with the same seed (0). Asserts that:

        * the total ``loss`` decreased;
        * if the solver has a prior model, ``recon2_loss + kldiv2_loss``
          decreased;
        * if ``solver.learn_recon_var`` is set, each model's
          ``log_recon_var`` changed.

        Does nothing when the solver has no posterior phase.
        """
        if not solver.has_posterior_phase:
            return

        data = solver.train_data
        liGAN.set_random_seed(0)

        # Copy the learned log variances out as plain Python floats before
        # gen_step runs, so the "did it change" checks below compare values
        # rather than a tensor against a float.
        if solver.learn_recon_var:
            gen_log_var0 = solver.gen_model.log_recon_var.item()
            if solver.has_prior_model:
                prior_log_var0 = solver.prior_model.log_recon_var.item()

        metrics0 = solver.gen_step(grid_type='poster')
        # Re-seed so the re-evaluation uses the same random draws as the step.
        liGAN.set_random_seed(0)
        _, metrics1 = solver.gen_forward(data, grid_type='poster')
        assert metrics1['loss'] < metrics0['loss'], 'loss did not decrease'

        if solver.has_prior_model:
            prior_loss0 = metrics0['recon2_loss'] + metrics0['kldiv2_loss']
            prior_loss1 = metrics1['recon2_loss'] + metrics1['kldiv2_loss']
            assert prior_loss1 < prior_loss0, (
                'prior model loss did not decrease')

        if solver.learn_recon_var:
            gen_log_var1 = solver.gen_model.log_recon_var.item()
            assert gen_log_var1 != gen_log_var0, (
                'gen log_recon_var did not change')
            if solver.has_prior_model:
                prior_log_var1 = solver.prior_model.log_recon_var.item()
                assert prior_log_var1 != prior_log_var0, (
                    'prior log_recon_var did not change')
Exemplo n.º 2
0
def main(argv):
    """Generate molecules as described by a YAML config file.

    Parses ``argv`` with ``parse_args`` and loads the YAML file named by
    ``args.config_file``. It then passes ``config['random_seed']`` (or None)
    to ``liGAN.set_random_seed``. Next it instantiates
    ``liGAN.generating.<model_type>Generator``, where ``model_type`` defaults
    to ``'Molecule'`` when missing or empty. Finally it calls its
    ``generate`` method with the ``generate`` section of the config.

    Args:
        argv: command-line arguments forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    with open(args.config_file) as f:
        config = yaml.safe_load(f)

    device = 'cuda'
    liGAN.set_random_seed(config.get('random_seed', None))

    # A missing *or* empty/null model_type falls back to the default.
    model_type = config.get('model_type', None) or 'Molecule'
    generator_type = getattr(liGAN.generating, model_type + 'Generator')
    generator = generator_type(
        out_prefix=config['out_prefix'],
        n_samples=config['generate']['n_samples'],
        fit_atoms=config['generate'].get('fit_atoms', True),
        data_kws=config['data'],
        gen_model_kws=config.get('gen_model', {}),
        prior_model_kws=config.get('prior_model', {}),
        atom_fitting_kws=config.get('atom_fitting', {}),
        bond_adding_kws=config.get('bond_adding', {}),
        output_kws=config['output'],
        device=device,
        verbose=config['verbose'],
        debug=args.debug,
    )
    # NOTE(review): the entire ``generate`` section is forwarded, including
    # n_samples/fit_atoms already consumed by the constructor above.
    generator.generate(**config['generate'])
    print('Done')
Exemplo n.º 3
0
 def test_solver_disc_step_prior(self, solver):
     """Check that a prior-phase discriminator step lowers the loss.

     Skipped (returns immediately) unless the solver has both a
     discriminator model and a prior phase. Otherwise, after
     ``disc_step(grid_type='prior')`` the same seed is reused to re-run
     ``disc_forward`` on ``solver.train_data``, whose loss must be lower.
     """
     if not (solver.has_disc_model and solver.has_prior_phase):
         return

     batch = solver.train_data
     liGAN.set_random_seed(0)
     before = solver.disc_step(grid_type='prior')

     # Same seed again so the forward pass sees identical random draws.
     liGAN.set_random_seed(0)
     _, after = solver.disc_forward(batch, grid_type='prior')

     assert after['loss'] < before['loss'], 'loss did not decrease'
Exemplo n.º 4
0
def main(argv):
    """Train a liGAN solver as described by a YAML config file.

    Steps:
      1. Set ``ob.obErrorLog`` output level to 0. This presumably silences
         Open Babel warnings; confirm.
      2. Parse ``argv`` and load the YAML file named by ``args.config_file``.
      3. If a ``wandb`` section has a truthy ``use_wandb``, call
         ``wandb.init``, passing the section's optional ``init_kwargs``. If
         ``out_prefix`` is unset, default it to ``wandb_output/<run id>``.
      4. Seed via ``liGAN.set_random_seed`` and build
         ``liGAN.training.<model_type>Solver``.
      5. If ``continue`` is truthy, try to load a saved state. ``True`` maps
         to ``cont_iter=None``; any other value is used as the iteration.
      6. If ``max_n_iters`` is set, cap ``train.max_iter`` at
         ``solver.gen_iter + max_n_iters``.
      7. Run ``solver.train_and_test`` with the ``train`` section.

    Args:
        argv: command-line arguments forwarded to ``parse_args``.

    Raises:
        ValueError: if a ``wandb`` section is present without ``use_wandb``.
    """
    ob.obErrorLog.SetOutputLevel(0)
    args = parse_args(argv)

    with open(args.config_file) as f:
        config = yaml.safe_load(f)

    if 'wandb' in config:
        wandb_cfg = config['wandb']
        if 'use_wandb' not in wandb_cfg:
            raise ValueError('use_wandb must be included in wandb configs')

        if wandb_cfg['use_wandb']:
            import wandb
            wandb.init(
                settings=wandb.Settings(start_method="fork"),
                config=config,
                **wandb_cfg.get('init_kwargs', {}),
            )
            if 'out_prefix' not in config:
                # Group outputs of wandb runs in one directory, keyed by run id.
                os.makedirs('wandb_output', exist_ok=True)
                config['out_prefix'] = 'wandb_output/' + wandb.run.id
                sys.stderr.write('Setting output prefix to {}\n'.format(
                    config['out_prefix']))

    device = 'cuda'
    liGAN.set_random_seed(config.get('random_seed', None))

    solver_type = getattr(liGAN.training, config['model_type'] + 'Solver')
    solver = solver_type(data_kws=config['data'],
                         wandb_kws=config.get('wandb', {'use_wandb': False}),
                         gen_model_kws=config['gen_model'],
                         disc_model_kws=config.get('disc_model', {}),
                         prior_model_kws=config.get('prior_model', {}),
                         loss_fn_kws=config['loss_fn'],
                         gen_optim_kws=config['gen_optim'],
                         disc_optim_kws=config.get('disc_optim', {}),
                         prior_optim_kws=config.get('prior_optim', {}),
                         atom_fitting_kws=config['atom_fitting'],
                         bond_adding_kws=config.get('bond_adding', {}),
                         out_prefix=config['out_prefix'],
                         device=device,
                         debug=args.debug,
                         sync_cuda=config.get('sync_cuda', False))

    if config['continue']:
        # ``continue: true`` -> None, presumably "latest saved state" (confirm
        # in load_state_and_metrics); any other value is an explicit iteration.
        cont_iter = None if config['continue'] is True else config['continue']
        try:
            solver.load_state_and_metrics(cont_iter)
        except FileNotFoundError as exc:
            # Best-effort resume: keep training from scratch, but say so
            # instead of silently ignoring the missing checkpoint.
            sys.stderr.write(
                'Could not continue from saved state ({}); '
                'training from scratch\n'.format(exc))

    if 'max_n_iters' in config:
        config['train']['max_iter'] = min(
            config['train']['max_iter'],
            solver.gen_iter + config['max_n_iters'])

    solver.train_and_test(**config['train'])