Example #1
def parse_configs():
    """Parse configs from JSON file."""
    parser = argparse.ArgumentParser(
        description='L2HMC algorithm applied to a 2D U(1) lattice gauge model.')
    parser.add_argument("--log_dir",
                        dest="log_dir",
                        type=str,
                        default=None,
                        required=False,
                        help=("""Log directory to use from previous run.  If
                        this argument is not passed, a new directory will be
                        created."""))
    parser.add_argument("--json_file",
                        dest="json_file",
                        type=str,
                        default=None,
                        required=True,
                        help=("""Path to JSON file containing configs."""))
    args = parser.parse_args()
    with open(args.json_file, 'rt') as f:
        targs = argparse.Namespace()
        targs.__dict__.update(json.load(f))
        args = parser.parse_args(namespace=targs)

    flags = AttrDict(args.__dict__)
    for key, val in flags.items():
        if isinstance(val, dict):
            flags[key] = AttrDict(val)

    return flags
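The two-pass parse above gives a precedence worth spelling out: the JSON file seeds the namespace, and any flags given explicitly on the command line win. A self-contained sketch of that behavior (the flag names here are hypothetical):

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument('--log_dir', default=None)
parser.add_argument('--train_steps', type=int, default=None)

# Seed the namespace from JSON, then re-parse: argparse only applies
# defaults for attributes the namespace doesn't already have.
ns = argparse.Namespace()
ns.__dict__.update(json.loads('{"log_dir": "/tmp/run0", "train_steps": 100}'))
args = parser.parse_args(['--log_dir', '/tmp/run1'], namespace=ns)
assert args.log_dir == '/tmp/run1'   # explicit CLI flag wins
assert args.train_steps == 100       # JSON value survives the re-parse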
Example #2
def train_hmc(flags):
    """Main method for training HMC model."""
    hflags = AttrDict(dict(flags).copy())
    lr_config = AttrDict(hflags.pop('lr_config', None))
    config = AttrDict(hflags.pop('dynamics_config', None))
    net_config = AttrDict(hflags.pop('network_config', None))
    hflags.train_steps = hflags.pop('hmc_steps', None)
    hflags.beta_init = hflags.beta_final

    config.update({
        'hmc': True,
        'use_ncp': False,
        'aux_weight': 0.,
        'zero_init': False,
        'separate_networks': False,
        'use_conv_net': False,
        'directional_updates': False,
        'use_scattered_xnet_update': False,
        'use_tempered_traj': False,
        'gauge_eq_masks': False,
    })

    hflags.profiler = False
    hflags.make_summaries = True

    lr_config = LearningRateConfig(
        warmup_steps=0,
        decay_rate=0.9,
        decay_steps=hflags.train_steps // 10,
        lr_init=lr_config.get('lr_init', None),
    )

    train_dirs = io.setup_directories(hflags, 'training_hmc')
    dynamics = GaugeDynamics(hflags, config, net_config, lr_config)
    dynamics.save_config(train_dirs.config_dir)
    x, train_data = train_dynamics(dynamics, hflags, dirs=train_dirs)
    if IS_CHIEF:
        output_dir = os.path.join(train_dirs.train_dir, 'outputs')
        io.check_else_make_dir(output_dir)
        train_data.save_data(output_dir)

        params = {
            'eps': dynamics.eps,
            'num_steps': dynamics.config.num_steps,
            'beta_init': hflags.beta_init,
            'beta_final': hflags.beta_final,
            'lattice_shape': dynamics.config.lattice_shape,
            'net_weights': NET_WEIGHTS_HMC,
        }
        plot_data(train_data,
                  train_dirs.train_dir,
                  hflags,
                  thermalize=True,
                  params=params)

    return x, train_data, dynamics.eps.numpy()
Example #3
def flatten_dict(d):
    """Recursively convert all entries of `d` to be `AttrDict`."""
    if not isinstance(d, AttrDict):
        d = AttrDict(**d)

    for key, val in d.items():
        if isinstance(val, dict):
            # AttrDict subclasses dict, so this recurses into nested
            # AttrDicts as well as plain dicts.
            d[key] = flatten_dict(val)

    return d
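Every example on this page leans on `AttrDict`. As a minimal stand-in — assuming, as the `isinstance` checks in `flatten_dict` imply, that it is a `dict` subclass exposing keys as attributes:

class AttrDict(dict):
    """Hypothetical minimal AttrDict: a dict readable/writable via attributes."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


d = flatten_dict({'a': 1, 'nested': {'b': 2}})
assert d.nested.b == 2  # nested dicts now support attribute access too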
Example #4
def parse_test_configs(test_configs_file=None):
    """Load test configs from JSON, defaulting to `BIN_DIR/test_configs.json`."""
    if test_configs_file is None:
        test_configs_file = os.path.join(BIN_DIR, 'test_configs.json')

    with open(test_configs_file, 'rt') as f:
        test_flags = json.load(f)

    test_flags = AttrDict(dict(test_flags))
    for key, val in test_flags.items():
        if isinstance(val, dict):
            test_flags[key] = AttrDict(val)

    return test_flags
Example #5
    def train_step(self, data):
        """Perform a single training step."""
        x, beta = data
        start = time.time()
        with tf.GradientTape() as tape:
            states, accept_prob, sumlogdet = self((x, beta), training=True)
            loss = self.calc_losses(states, accept_prob)

            if self.aux_weight > 0:
                z = tf.random.normal(x.shape, dtype=x.dtype)
                states_, accept_prob_, _ = self((z, beta), training=True)
                loss_ = self.calc_losses(states_, accept_prob_)
                loss += loss_

        if NUM_RANKS > 1:
            tape = hvd.DistributedGradientTape(tape)

        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        metrics = AttrDict({
            'dt': time.time() - start,
            'loss': loss,
            'accept_prob': accept_prob,
            'eps': self.eps,
            'beta': states.init.beta,
            'sumlogdet': sumlogdet.out,
        })

        # Broadcast initial variable states from rank 0 to all other ranks
        # once, right after the first gradient update (`apply_gradients` has
        # already incremented the counter, so the first step sees 1 here).
        if self.optimizer.iterations == 1 and NUM_RANKS > 1:
            hvd.broadcast_variables(self.variables, root_rank=0)
            hvd.broadcast_variables(self.optimizer.variables(), root_rank=0)

        return states.out.x, metrics
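For context, the `DistributedGradientTape` / `broadcast_variables` calls above follow Horovod's standard TF2 data-parallel recipe. A minimal sketch of that general pattern (a toy model, not this repo's dynamics):

import horovod.tensorflow as hvd
import tensorflow as tf

hvd.init()
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = tf.keras.optimizers.Adam(1e-3 * hvd.size())  # scale lr with world size

def step(x, y, first_batch):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean((model(x) - y) ** 2)
    tape = hvd.DistributedGradientTape(tape)   # allreduces on .gradient()
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    if first_batch:
        # Sync initial state across ranks once variables exist.
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(opt.variables(), root_rank=0)
    return loss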
Example #6
def _decode_cfg_value(v):
    """Decodes a raw config value (e.g., from a yaml config files or command
    line argument) into a Python object.
    """
    # Configs parsed from raw yaml will contain dictionary keys that need to be
    # converted to AttrDict objects
    if isinstance(v, dict):
        return AttrDict(v)
    # All remaining processing is only applied to strings
    if not isinstance(v, str):
        return v
    # Try to interpret `v` as a:
    #   string, number, tuple, list, dict, boolean, or None
    try:
        v = literal_eval(v)
    # The following two excepts allow v to pass through when it represents a
    # string.
    #
    # Longer explanation:
    # The type of v is always a string (before calling literal_eval), but
    # sometimes it *represents* a string and other times a data structure, like
    # a list. In the case that v represents a string, what we got back from the
    # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
    # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
    # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
    # will raise a SyntaxError.
    except (ValueError, SyntaxError):
        pass
    return v
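A quick illustration of the pass-through behavior described in the comments above: `literal_eval` converts literal-looking strings and leaves bare strings untouched.

from ast import literal_eval

def decode(v):  # condensed version of _decode_cfg_value's string branch
    try:
        return literal_eval(v)
    except (ValueError, SyntaxError):
        return v

assert decode('3.14') == 3.14          # number
assert decode('[1, 2]') == [1, 2]      # list
assert decode('True') is True          # boolean
assert decode('foo') == 'foo'          # bare string passes through
assert decode('foo/bar') == 'foo/bar'  # path-like string passes through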
Example #7
def main(exp_config):
    # for convenient attribute access
    exp_config = AttrDict(exp_config)

    model = MHUnet(use_dropout=exp_config.use_dropout,
                   complementary=exp_config.complementary,
                   multitask=exp_config.multitask,
                   conditioning=exp_config.conditioning,
                   use_bias=exp_config.use_bias,
                   n_decoders=exp_config.num_decoders,
                   num_downs=exp_config.num_downblocks)

    if exp_config.dataset == 'urmp':
        ds = URMPSpec(dataset_dir=exp_config.dataset_dir,
                      context=bool(exp_config.conditioning))
    elif exp_config.dataset == 'solos':
        ds = SolosSpec('test',
                       data_dir=exp_config.dataset_dir,
                       load_specs=(exp_config.input_type == 'spec_load'),
                       context=bool(exp_config.conditioning))
    else:
        raise ValueError(f'Unknown dataset: {exp_config.dataset}')

    loader = torch.utils.data.DataLoader(ds,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=0)

    pipeline = ModelActionPipeline(model=model,
                                   train_loader=loader,
                                   val_loader=loader,
                                   exp_config=exp_config)

    checkpoint_path = os.path.join(exp_config.dir_checkpoint,
                                   exp_config.model_checkpoint)
    pipeline.test_model(checkpoint_path, exp_config.output_dir)
Example #8
def loop_over_log_dirs():
    """Run inference over each run directory found under `gauge_logs_eager`."""
    rld1 = os.path.join(BASE_DIR, 'gauge_logs_eager', '2020_07')
    rld2 = os.path.join(BASE_DIR, 'gauge_logs_eager', '2020_06')
    ld1 = [
        os.path.join(rld1, i) for i in os.listdir(rld1)
        if os.path.isdir(os.path.join(rld1, i))
    ]
    ld2 = [
        os.path.join(rld2, i) for i in os.listdir(rld2)
        if os.path.isdir(os.path.join(rld2, i))
    ]

    log_dirs = ld1 + ld2
    for log_dir in log_dirs:
        args = AttrDict({
            'hmc': False,
            'run_steps': 2000,
            'overwrite': True,
            'log_dir': log_dir,
        })

        try:
            run(args, log_dir, random_start=True)
        except Exception as e:
            # Don't let one bad log_dir abort the whole sweep, but say why.
            io.log(f'Skipping {log_dir}: {e}')
Example #9
def multiple_runs(flags, json_file=None):
    """Run a randomized grid of (beta, num_steps, eps) combinations."""
    default = (512, 16, 16, 2)
    shape = flags.x_shape if flags.x_shape is not None else default

    num_steps = [5, 10]
    eps = [0.05, 0.1, 0.2]
    betas = [2., 3., 4., 5., 6., 7.]

    for b in random.sample(betas, len(betas)):
        for ns in random.sample(num_steps, len(num_steps)):
            for e in random.sample(eps, len(eps)):
                run_steps = 50000 if b < 5. else 100000
                args = AttrDict({
                    'eps': e,
                    'beta': b,
                    'num_steps': ns,
                    'run_steps': run_steps,
                    'x_shape': shape,
                    'skip_existing': (not flags.overwrite),
                })
                if not flags.overwrite and check_existing(b, ns, e):
                    io.rule('Skipping existing run!')
                    continue

                _ = main(args, json_file=json_file)
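One idiom worth calling out from the loop above: `random.sample(seq, len(seq))` produces a shuffled copy without mutating the original grid.

import random

betas = [2., 3., 4.]
shuffled = random.sample(betas, len(betas))  # new list, random order
assert sorted(shuffled) == betas             # same elements
assert betas == [2., 3., 4.]                 # original order untouched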
Example #10
def build_test_dynamics():
    """Build quick test dynamics for debugging."""
    jfile = os.path.abspath(os.path.join(BIN_DIR, 'test_dynamics_flags.json'))
    with open(jfile, 'rt') as f:
        flags = json.load(f)
    flags = AttrDict(flags)
    return build_dynamics(flags)
Example #11
def load_hmc_flags():
    """Load HMC flags from `BIN_DIR/hmc_configs.json`."""
    cfg_file = os.path.join(BIN_DIR, 'hmc_configs.json')
    with open(cfg_file, 'rt') as f:
        flags = json.load(f)

    return AttrDict(flags)
Example #12
    def test_reversibility(self,
                           data: Union[tf.Tensor, List[tf.Tensor]],
                           training: bool = None):
        """Test reversibility.

        NOTE:
         1. Run forward then backward
                 (x, v) -> (xf, vf)
                 (xf, vf) -> (xb, vb)
            check that x == xb, v == vb

         2. Run backward then forward
                 (x, v) -> (xb, vb)
                 (xb, vb) -> (xf, vf)
            check that x == xf, v == vf
        """
        dxf, dvf = self._test_reversibility(data,
                                            forward_first=True,
                                            training=training)
        dxb, dvb = self._test_reversibility(data,
                                            forward_first=False,
                                            training=training)
        output = AttrDict({
            'dxf': dxf,
            'dvf': dvf,
            'dxb': dxb,
            'dvb': dvb,
        })

        return output
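The check the docstring describes is the usual one for reversible integrators. As a self-contained illustration — an ordinary leapfrog on a harmonic oscillator standing in for the learned dynamics:

import numpy as np

def leapfrog(x, v, eps, n):
    """Symmetric leapfrog for U(x) = x^2 / 2; negative eps runs it backward."""
    for _ in range(n):
        v = v - 0.5 * eps * x   # half kick: force is -dU/dx = -x
        x = x + eps * v         # full drift
        v = v - 0.5 * eps * x   # half kick
    return x, v

x0, v0 = np.random.randn(4), np.random.randn(4)
xf, vf = leapfrog(x0, v0, eps=0.1, n=10)    # forward...
xb, vb = leapfrog(xf, vf, eps=-0.1, n=10)   # ...then backward
assert np.allclose(x0, xb) and np.allclose(v0, vb)  # recovers the start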
Example #13
def restore_flags(flags, train_dir):
    """Update `FLAGS` using restored flags from `log_dir`."""
    rf_file = os.path.join(train_dir, 'FLAGS.z')
    restored = AttrDict(dict(io.loadz(rf_file)))
    io.log(f'Restoring FLAGS from: {rf_file}...')
    flags.update(restored)

    return flags
Example #14
    def __init__(self, steps, header=None, dirs=None, print_steps=100):
        self.steps = steps
        self.print_steps = print_steps
        self.dirs = dirs
        self.data_strs = [header]
        self.steps_arr = []
        self.data = AttrDict(defaultdict(list))
        if dirs is not None:
            io.check_else_make_dir(
                [v for k, v in dirs.items() if 'file' not in k])
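One subtlety in the `self.data` line above: wrapping a `defaultdict` in another dict type copies its (empty) contents but not its default factory, so `self.data` ends up as a plain empty `AttrDict`.

from collections import defaultdict

dd = defaultdict(list)
copied = dict(dd)     # entries are copied; the list factory is not
dd['a'].append(1)     # fine: defaultdict materializes the missing list
'a' in copied         # False, and copied['missing'] would raise KeyError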
Example #15
def test_resume_training(log_dir: str):
    """Test restoring a training session from a checkpoint."""
    flags = AttrDict(
        dict(io.loadz(os.path.join(log_dir, 'training', 'FLAGS.z'))))

    flags.log_dir = log_dir
    # Extend `train_steps` so training resumes past the restored checkpoint.
    flags.train_steps += flags.get('train_steps', 10)
    x, dynamics, train_data, flags = train(flags)
    dynamics, run_data, x = run(dynamics, flags, x=x)

    return AttrDict({
        'x': x,
        'flags': flags,
        'log_dir': flags.log_dir,
        'dynamics': dynamics,
        'run_data': run_data,
        'train_data': train_data,
    })
Example #16
    def test_step(self, data):
        """Perform a single inference step."""
        start = time.time()
        states, data = self(data, training=False)
        accept_prob = data.get('accept_prob', None)
        ploss, qloss = self.calc_losses(states, accept_prob)
        loss = ploss + qloss

        metrics = AttrDict({
            'dt': time.time() - start,
            'loss': loss,
        })
        if self.plaq_weight > 0 and self.charge_weight > 0:
            metrics.update({'ploss': ploss, 'qloss': qloss})

        metrics.update({
            'accept_prob': accept_prob,
            'eps': self.eps,
            'beta': states.init.beta,
        })

        if self._verbose:
            mid = self.config.num_steps // 2
            metrics.update({
                'Hf_start': data.forward.energies[0],
                'Hf_mid': data.forward.energies[mid],
                'Hf_end': data.forward.energies[-1],
                'Hb_start': data.backward.energies[0],
                'Hb_mid': data.backward.energies[mid],
                'Hb_end': data.backward.energies[-1],
                'ld_f_start': data.forward.logdets[0],
                'ld_f_mid': data.forward.logdets[mid],
                'ld_f_end': data.forward.logdets[-1],
                'ld_b_start': data.backward.logdets[0],
                'ld_b_mid': data.backward.logdets[mid],
                'ld_b_end': data.backward.logdets[-1],
            })

        observables = self.calc_observables(states)
        metrics.update(**observables)

        return states.out.x, metrics
Example #17
    def transition_kernel_directional(
        self,
        state: State,
        forward: bool,
        training: bool = None,
    ):
        """Implements a series of directional updates."""
        state_prop = State(state.x, state.v, state.beta)
        sumlogdet = tf.zeros((self.batch_size, ), dtype=TF_FLOAT)
        logdets = tf.TensorArray(TF_FLOAT,
                                 dynamic_size=True,
                                 size=self.batch_size,
                                 clear_after_read=True)
        energies = tf.TensorArray(TF_FLOAT,
                                  dynamic_size=True,
                                  size=self.batch_size,
                                  clear_after_read=True)
        # ====
        # Forward for first half of trajectory
        for step in range(self.config.num_steps // 2):
            if self._verbose:
                logdets = logdets.write(step, sumlogdet)
                energies = energies.write(step, self.hamiltonian(state_prop))

            state_prop, logdet = self._forward_lf(step, state_prop, training)
            sumlogdet += logdet

        # ====
        # Flip momentum
        state_prop = State(state_prop.x, -1. * state_prop.v, state_prop.beta)

        # ====
        # Backward for second half of trajectory
        for step in range(self.config.num_steps // 2, self.config.num_steps):
            if self._verbose:
                logdets = logdets.write(step, sumlogdet)
                energies = energies.write(step, self.hamiltonian(state_prop))

            state_prop, logdet = self._backward_lf(step, state_prop, training)
            sumlogdet += logdet

        accept_prob = self.compute_accept_prob(state, state_prop, sumlogdet)
        metrics = AttrDict({
            'sumlogdet': sumlogdet,
            'accept_prob': accept_prob,
        })
        if self._verbose:
            metrics.update({
                'energies': [energies.read(i)
                             for i in range(self.config.num_steps)],
                'logdets': [logdets.read(i)
                            for i in range(self.config.num_steps)],
            })

        return state_prop, metrics
Example #18
    def _transition_kernel_backward(self, state: State, training: bool = None):
        """Run the augmented leapfrog sampler in the backward direction."""
        kwargs = {
            'dynamic_size': True,
            'size': self.batch_size,
            'clear_after_read': True
        }
        logdets = tf.TensorArray(TF_FLOAT, **kwargs)
        energies = tf.TensorArray(TF_FLOAT, **kwargs)
        sumlogdet = tf.zeros((self.batch_size, ))
        state_prop = State(state.x, state.v, state.beta)

        state_prop, logdet = self._half_v_update_backward(
            state_prop, 0, training)
        sumlogdet += logdet
        for step in range(self.config.num_steps):
            if self._verbose:
                logdets = logdets.write(step, sumlogdet)
                energies = energies.write(step, self.hamiltonian(state_prop))

            state_prop, logdet = self._full_x_update_backward(
                state_prop, step, training)
            sumlogdet += logdet

            if step < self.config.num_steps - 1:
                state_prop, logdet = self._full_v_update_backward(
                    state_prop, step, training)
                sumlogdet += logdet

        state_prop, logdet = self._half_v_update_backward(
            state_prop, step, training)
        sumlogdet += logdet

        accept_prob = self.compute_accept_prob(state, state_prop, sumlogdet)

        metrics = AttrDict({
            'sumlogdet': sumlogdet,
            'accept_prob': accept_prob,
        })
        if self._verbose:
            logdets = logdets.write(self.config.num_steps, sumlogdet)
            energies = energies.write(self.config.num_steps,
                                      self.hamiltonian(state_prop))
            # Include the final (post-trajectory) entry written just above.
            metrics.update({
                'energies': [energies.read(i)
                             for i in range(self.config.num_steps + 1)],
                'logdets': [logdets.read(i)
                            for i in range(self.config.num_steps + 1)],
            })

        return state_prop, metrics
Example #19
    def load_data(data_dir):
        """Load data from `data_dir` and populate `self.data`."""
        contents = os.listdir(data_dir)
        fnames = [i for i in contents if i.endswith('.z')]
        # Slice off the extension; `rstrip('.z')` would strip a character set.
        keys = [i[:-len('.z')] for i in fnames]
        data_files = [os.path.join(data_dir, i) for i in fnames]
        data = {}
        for key, val in zip(keys, data_files):
            if 'x_rank' in key:
                continue
            io.log(f'Restored {key} from {val}.')
            data[key] = io.loadz(val)

        return AttrDict(data)
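Why the slice rather than `str.rstrip` for the extension: `rstrip` treats its argument as a *set* of trailing characters to remove, not a literal suffix.

'charges.z'.rstrip('.z')   # -> 'charges' (happens to work)
'buzz.z'.rstrip('.z')      # -> 'bu'      (every trailing '.' or 'z' stripped)
'buzz.z'[:-len('.z')]      # -> 'buzz'    (removes exactly the suffix)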
Example #20
def test_single_network(flags: AttrDict):
    """Test training on single network."""
    flags.dynamics_config.separate_networks = False
    x, dynamics, train_data, flags = train(flags)
    dynamics, run_data, x = run(dynamics, flags, x=x)

    return AttrDict({
        'x': x,
        'flags': flags,
        'log_dir': flags.log_dir,
        'dynamics': dynamics,
        'run_data': run_data,
        'train_data': train_data,
    })
Example #21
def multiple_runs():
    """Run inference over a grid of (beta, eps) pairs."""
    num_steps = 10
    run_steps = 5000
    betas = [2., 3., 4., 5., 6.]
    eps = [0.05, 0.075, 0.225, 0.25, 0.275]
    for b in betas:
        for e in eps:
            args = AttrDict({
                'eps': e,
                'beta': b,
                'num_steps': num_steps,
                'run_steps': run_steps
            })
            _ = main(args)
Example #22
    def calc_observables(self, states):
        """Calculate observables."""
        _, q_init_sin, q_init_proj = self._calc_observables(states.init)
        plaqs, q_out_sin, q_out_proj = self._calc_observables(states.out)
        dq_sin = tf.math.abs(q_out_sin - q_init_sin)
        dq_proj = tf.math.abs(q_out_proj - q_init_proj)

        observables = AttrDict({
            'dq': dq_proj,
            'dq_sin': dq_sin,
            'charges': q_out_proj,
            'plaqs': plaqs,
        })

        return observables
Example #23
def test_hmc_run(
    configs: dict[str, Any],
    make_plots: bool = True,
) -> TestOutputs:
    """Testing generic HMC."""
    logger.info(f'Testing generic HMC')
    t0 = time.time()
    configs = AttrDict(**dict(copy.deepcopy(configs)))
    configs['dynamics_config']['hmc'] = True
    #  hmc_dir = os.path.join(os.path.dirname(PROJECT_DIR),
    #                         'gauge_logs_eager', 'test', 'hmc_runs')
    hmc_dir = os.path.join(GAUGE_LOGS_DIR, 'hmc_test_logs')
    run_out = run_hmc(configs, hmc_dir=hmc_dir, make_plots=make_plots)

    logger.info(f'Passed! Took: {time.time() - t0:.4f} seconds')
    return TestOutputs(None, run_out)
Example #24
def build_dynamics(flags):
    """Build dynamics using configs from FLAGS."""
    lr_config = LearningRateConfig(**dict(flags.get('lr_config', None)))
    config = AttrDict(flags.get('dynamics_config', None))
    net_config = NetworkConfig(**dict(flags.get('network_config', None)))
    conv_config = None

    if config.get('use_conv_net', False):
        conv_config = flags.get('conv_config', None)
        input_shape = config.get('lattice_shape', None)[1:]
        conv_config.update({
            'input_shape': input_shape,
        })
        conv_config = ConvolutionConfig(**conv_config)

    return GaugeDynamics(flags, config, net_config, lr_config, conv_config)
Example #25
    def test_step(self, data):
        """Perform a single inference step."""
        x, beta = data
        start = time.time()
        states, accept_prob, sumlogdet = self((x, beta), training=False)
        loss = self.calc_losses(states, accept_prob)

        metrics = AttrDict({
            'dt': time.time() - start,
            'loss': loss,
            'accept_prob': accept_prob,
            'eps': self.eps,
            'beta': states.init.beta,
            'sumlogdet': sumlogdet.out,
        })

        return states.out.x, metrics
Example #26
    def transition_kernel(
        self,
        state: State,
        forward: bool,
        training: bool = None,
    ):
        """Transition kernel of the augmented leapfrog integrator."""
        lf_fn = self._forward_lf if forward else self._backward_lf

        state_prop = State(x=state.x, v=state.v, beta=state.beta)
        sumlogdet = tf.zeros((self.batch_size, ), dtype=TF_FLOAT)
        logdets = tf.TensorArray(TF_FLOAT,
                                 dynamic_size=True,
                                 size=self.batch_size,
                                 clear_after_read=True)
        energies = tf.TensorArray(TF_FLOAT,
                                  dynamic_size=True,
                                  size=self.batch_size,
                                  clear_after_read=True)

        for step in range(self.config.num_steps):
            if self._verbose:
                logdets = logdets.write(step, sumlogdet)
                energies = energies.write(step, self.hamiltonian(state_prop))

            state_prop, logdet = lf_fn(step, state_prop, training)
            sumlogdet += logdet

        accept_prob = self.compute_accept_prob(state, state_prop, sumlogdet)
        metrics = AttrDict({
            'sumlogdet': sumlogdet,
            'accept_prob': accept_prob,
        })
        if self._verbose:
            logdets = logdets.write(self.config.num_steps, sumlogdet)
            energies = energies.write(self.config.num_steps,
                                      self.hamiltonian(state_prop))
            # Include the final (post-trajectory) entry written just above.
            metrics.update({
                'energies': [energies.read(i)
                             for i in range(self.config.num_steps + 1)],
                'logdets': [logdets.read(i)
                            for i in range(self.config.num_steps + 1)],
            })

        return state_prop, metrics
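The `logdets = logdets.write(...)` reassignments above are the required `tf.TensorArray` idiom: `write` returns a new array object that must be kept. A minimal sketch of the bookkeeping pattern, independent of the dynamics code:

import tensorflow as tf

ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
running = tf.constant(0.)
for step in range(3):
    ta = ta.write(step, running)       # keep the returned TensorArray
    running = running + 1.
vals = [ta.read(i) for i in range(3)]  # tensors holding 0., 1., 2.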
Example #27
    def load_data(data_dir: Union[str, Path]):
        """Load data from `data_dir` and populate `self.data`."""
        # TODO: Update to use h5py for `.hdf5` files
        contents = os.listdir(data_dir)
        fnames = [i for i in contents if i.endswith('.z')]
        # Slice off the extension; `rstrip('.z')` would strip a character set.
        keys = [i[:-len('.z')] for i in fnames]
        data_files = [os.path.join(data_dir, i) for i in fnames]
        data = {}
        for key, val in zip(keys, data_files):
            if 'x_rank' in key:
                continue

            data[key] = io.loadz(val)
            if VERBOSE:
                logger.info(f'Restored {key} from {val}.')

        return AttrDict(data)
Example #28
def main(exp_config):
    # for convenient attribute access
    exp_config = AttrDict(exp_config)

    # Select Dataset
    if exp_config.dataset == 'urmp':
        if exp_config.with_resnet:
            ds = URMPMM(dataset_dir=exp_config.dataset_dir,
                        context=bool(exp_config.conditioning),
                        n_visual_frames=exp_config.n_visual_frames)
        else:
            ds = URMPMMFeatures(dataset_dir=exp_config.dataset_dir,
                                context=bool(exp_config.conditioning),
                                n_visual_frames=exp_config.n_visual_frames)
    elif exp_config.dataset == 'solos':
        if not exp_config.with_resnet:
            ds = SolosMMFeatures('test',
                                 data_dir=exp_config.dataset_dir,
                                 n_mix_max=7,
                                 context=bool(exp_config.conditioning),
                                 n_visual_frames=exp_config.n_visual_frames)
        else:
            ds = SolosMM('test',
                         data_dir=exp_config.dataset_dir,
                         n_mix_max=7,
                         context=bool(exp_config.conditioning),
                         n_visual_frames=exp_config.n_visual_frames)
    else:
        raise ValueError(f'Unknown dataset: {exp_config.dataset}')

    loader = torch.utils.data.DataLoader(ds,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=0)

    # Select Multimodal Model Conditioned on Labels of Features
    model = FeatureConditioned(exp_config)

    pipeline = ModelActionPipeline(model=model,
                                   train_loader=loader,
                                   val_loader=loader,
                                   exp_config=exp_config)

    checkpoint_path = os.path.join(exp_config.dir_checkpoint,
                                   exp_config.model_checkpoint)
    pipeline.test_model(checkpoint_path, exp_config.output_dir)
Example #29
def test_separate_networks(flags: AttrDict):
    """Test training on separate networks."""
    flags.hmc_steps = 0
    flags.log_dir = io.make_log_dir(flags, 'GaugeModel', LOG_FILE)

    flags.dynamics_config.separate_networks = True
    flags.compile = False
    x, dynamics, train_data, flags = train(flags)
    dynamics, run_data, x = run(dynamics, flags, x=x)

    return AttrDict({
        'x': x,
        'flags': flags,
        'log_dir': flags.log_dir,
        'dynamics': dynamics,
        'run_data': run_data,
        'train_data': train_data,
    })
Example #30
def main(exp_config):
    # for convenient attribute access
    exp_config = AttrDict(exp_config)

    # Select Multimodal Model Conditioned on Labels of Features
    model = FeatureConditioned(exp_config)

    # Select Dataset
    if not exp_config.with_resnet:
        loaders = [
            torch.utils.data.DataLoader(SolosMMFeatures(
                data_type,
                data_dir=exp_config.dataset_dir,
                n_mix_max=2 if exp_config.curriculum_training else 7,
                context=bool(exp_config.conditioning),
                n_visual_frames=exp_config.n_visual_frames),
                                        batch_size=exp_config.batch_size,
                                        shuffle=shuffle,
                                        num_workers=8)
            for data_type, shuffle in zip(['train', 'test'], [True, False])
        ]
    else:
        loaders = [
            torch.utils.data.DataLoader(SolosMM(
                data_type,
                data_dir=exp_config.dataset_dir,
                n_mix_max=2 if exp_config.curriculum_training else 7,
                context=bool(exp_config.conditioning),
                n_visual_frames=exp_config.n_visual_frames),
                                        batch_size=exp_config.batch_size,
                                        shuffle=shuffle,
                                        num_workers=8)
            for data_type, shuffle in zip(['train', 'test'], [True, False])
        ]

    pipeline = ModelActionPipeline(model=model,
                                   train_loader=loaders[0],
                                   val_loader=loaders[1],
                                   exp_config=exp_config)

    pipeline.train_model()