示例#1
0
文件: main.py 项目: nftqcd/fthmc
def run_fthmc(
        flow: nn.ModuleList,
        config: TrainConfig,
        lfconfig: lfConfig,
        xi: torch.Tensor = None,
        nprint: int = 50,
        nplot: int = 10,
        window: int = 1,
        num_trajs: int = 1024,
        **kwargs,
):
    """Run field-transformed HMC (ftHMC) using an already-trained flow.

    Builds a ``FieldTransformation`` around ``flow``, runs ``num_trajs``
    trajectories through ``FieldTransformation.run`` and saves the returned
    history under ``<config.logdir>/ftHMC/<lfconfig.uniquestr()>/``.
    TensorBoard summaries go to the ``summaries`` subdirectory. The ``plots``
    subdirectory path is passed on as ``plotdir``.

    Args:
        flow: Trained flow layers. They are moved to CUDA (in place) when it
            is available and switched to eval mode.
        config: Training configuration. ``config.logdir`` is the root output
            directory, and the whole object is passed to
            ``FieldTransformation``.
        lfconfig: Configuration passed to ``FieldTransformation``. Its
            ``uniquestr()`` names the output subdirectory.
        xi: Starting configuration, forwarded as ``x`` to ``ft.run``. It may be
            ``None``; what ``ft.run`` does in that case is defined there.
        nprint: Forwarded unchanged to ``FieldTransformation.run``.
        nplot: Forwarded unchanged to ``FieldTransformation.run``.
        window: Forwarded unchanged to ``FieldTransformation.run``.
        num_trajs: Number of trajectories, forwarded to
            ``FieldTransformation.run`` and also used in the log message.
        **kwargs: Extra keyword arguments forwarded to
            ``FieldTransformation.run``.

    Returns:
        dict: ``{'field_transformation': ft, 'history': history}``, where
        ``history`` is the value returned by ``ft.run``. The same value is
        also saved to ``history.z`` via ``joblib.dump``.
    """
    logger.rule(f'Running `ftHMC` using trained flow for {num_trajs} trajs')
    if torch.cuda.is_available():
        flow.to('cuda')

    # eval() only switches modules to evaluation mode; it does not disable
    # autograd, so gradient-based parts of the run keep working.
    flow.eval()

    ft = FieldTransformation(flow=flow, config=config, lfconfig=lfconfig)
    fthmcdir = os.path.join(config.logdir, 'ftHMC', lfconfig.uniquestr())
    pdir = os.path.join(fthmcdir, 'plots')
    sdir = os.path.join(fthmcdir, 'summaries')
    # joblib.dump below does not create missing directories. Create the
    # output directory explicitly instead of relying on SummaryWriter
    # creating `sdir` (and hence its parent) as a side effect.
    os.makedirs(fthmcdir, exist_ok=True)

    writer = SummaryWriter(log_dir=sdir)
    try:
        history = ft.run(x=xi, nprint=nprint, nplot=nplot, window=window,
                         num_trajs=num_trajs, writer=writer, plotdir=pdir,
                         **kwargs)
    finally:
        # Flush pending summaries and release the event file even if the run
        # fails. The writer is not returned, so nothing here uses it later.
        writer.close()

    histfile = os.path.join(fthmcdir, 'history.z')
    logger.log(f'Saving history to: {histfile}')
    joblib.dump(history, histfile)

    return {'field_transformation': ft, 'history': history}
示例#2
0
    def __init__(self, num_classes: int, base_net: nn.ModuleList, source_layer_indexes: List[int],
                 extras: nn.ModuleList, classification_headers: nn.ModuleList,
                 regression_headers: nn.ModuleList, is_test=False, config=None, device=None):
        """Compose a SSD model using the given components.

        Args:
            num_classes: Stored as ``self.num_classes``. Presumably this is the
                number of classes scored by ``classification_headers``
                (confirm).
            base_net: Backbone layers. This is the only component moved to
                ``device`` in this constructor.
            source_layer_indexes: Stored as-is. Entries may be plain values or
                tuples. For every tuple entry that is not a ``GraphPath``, its
                second element (``t[1]``) is collected into
                ``self.source_layer_add_ons``.
                NOTE(review): the ``List[int]`` annotation is narrower than the
                tuple entries this code handles.
            extras: Stored as-is and not moved to ``device`` here.
            classification_headers: Stored as-is and not moved to ``device``
                here.
            regression_headers: Stored as-is and not moved to ``device`` here.
                NOTE(review): presumably the caller moves the whole model
                later. Confirm, since only ``base_net`` is moved above.
            is_test: When true, ``config.priors`` is moved to ``device`` and
                stored as ``self.priors``.
            config: Stored as ``self.config``. When ``is_test`` is true it must
                be non-None and have a ``priors`` attribute supporting
                ``.to(...)``; otherwise an ``AttributeError`` is raised.
            device: Target passed to ``.to(...)`` for ``base_net`` and
                ``config.priors``, and stored as ``self.device``.
                NOTE(review): there is no fallback when this is left as
                ``None``. ``.to(None)`` is presumably a no-op; confirm that
                this is intended.
        """
        super(SSD, self).__init__()
        self.device = device
        self.num_classes = num_classes
        # Only the backbone is moved to `device` in this constructor.
        self.base_net = base_net.to(self.device)
        self.source_layer_indexes = source_layer_indexes
        self.extras = extras
        self.classification_headers = classification_headers
        self.regression_headers = regression_headers
        self.is_test = is_test
        self.config = config

        # Collect the add-on layers carried in tuple entries (excluding
        # GraphPath entries). Wrapping them in nn.ModuleList registers them
        # as submodules, so their parameters are tracked by the model.
        self.source_layer_add_ons = nn.ModuleList([t[1] for t in source_layer_indexes
                                                   if isinstance(t, tuple) and not isinstance(t, GraphPath)])



        # Test/inference mode: cache the priors on `device`.
        # NOTE(review): re-assigning self.config below is redundant; it was
        # already set above.
        if is_test:
            self.config = config
            self.priors = config.priors.to(self.device)