Example #1
    def download_file(self, file_id, dest_folder, image_id=0):
        params = {'id': file_id}
        download_url = utils.gs.get_selector(self.server, utils.gs.DOWNLOAD)
        r = requests.get(download_url,
                         headers=self.headers,
                         params=params,
                         stream=True,
                         cookies=self.cookies)
        name = findall(r'attachment; filename="(.*)"',
                       r.headers['content-disposition'])
        if len(name) > 0 and name[0]:
            if image_id:
                file_name = os.path.join(
                    dest_folder, '%s_%s' % (str(image_id), str(name[0])))
            else:
                file_name = os.path.join(dest_folder, str(name[0]))
        else:
            Logger.log(LogLevel.ERROR, 'ERROR FILE NAME')
            return False
        try:
            with open(file_name, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
                        f.flush()
        except IOError as e:
            Logger.log(LogLevel.ERROR, 'Save file IO_ERROR', e)
            return False
        return True
Example #2
    def print_array(array, text):
        array_size = len(array)
        if array_size > 0:
            Logger.log(LogLevel.INFO, '----------- %s %s files -----------' % (
                str(array_size), text))
            for one_file in array:
                Logger.log(LogLevel.INFO, str(one_file))
Example #3
    def query_directory(self, first_time):
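        # Post a query for all pending files; on the first call, store the session cookie from the response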
        query_url = utils.gs.get_selector(self.server, utils.gs.QUERY)
        payload = self.payload.create_query_payload(self.pending)
        Logger.log(LogLevel.DEBUG, payload)

        try:
            if first_time:
                resp = requests.post(query_url,
                                     data=payload,
                                     headers=self.headers,
                                     verify=self.verify)
                self.set_cookie(resp)
                Logger.log(LogLevel.DEBUG, 'set cookie=', self.cookies)
            else:
                Logger.log(LogLevel.DEBUG, 'cookie=', self.cookies)
                resp = requests.post(query_url,
                                     data=payload,
                                     headers=self.headers,
                                     cookies=self.cookies,
                                     verify=self.verify)
            if not self.handle_response(resp, first_time):
                return False
        except IOError as e:
            Logger.log(LogLevel.ERROR, 'IO_ERROR', e)
            return False
        return True
Example #4
    def handle_response(self, json_response, first_time=False):
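        # Parse the JSON response, handle TE/TEX results per file, and move files with no remaining features to finished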

        if json_response.status_code != 200:
            Logger.log(LogLevel.ERROR, json_response.status_code)
            Logger.log(LogLevel.ERROR, json_response.text)
            if json_response.status_code == 400:
                Logger.log(LogLevel.ERROR,
                           "Bad request: please fix and rerun the command")
                exit(-1)
            return False

        parse_json = json.loads(json_response.text)
        Logger.log(LogLevel.DEBUG,
                   json.dumps(parse_json, indent=4, sort_keys=True))
        response_list = parse_json[utils.gs.RESPONSE]

        if not isinstance(response_list, list):
            response_list = [response_list]

        for response_object in response_list:
            file_data = self.pending.get(response_object[utils.gs.MD5])
            if utils.gs.TE in file_data.features \
                    and utils.gs.TE in response_object:
                found = TeData.handle_te_response(file_data, response_object,
                                                  first_time)
                if found:
                    self.download_reports(response_object[utils.gs.TE])
            if utils.gs.TEX in file_data.features \
                    and utils.gs.TEX in response_object:
                TexData.handle_tex_response(file_data, response_object,
                                            first_time)
                if TexData.extracted_file_download_id:
                    extraction_id = TexData.extracted_file_download_id
                    if not self.download_tex_result(extraction_id):
                        Logger.log(LogLevel.ERROR,
                                   'Failed to download extraction_id:',
                                   extraction_id)
                        file_data.tex = TexData.error(
                            "Unable to download file_id=%s" % extraction_id)
                        return True
                    else:
                        file_data.tex = TexData.log(
                            "Cleaned file was downloaded successfully, file_id=%s"
                            % extraction_id)
            if not file_data.features:
                self.finished.append(self.pending.pop(file_data.md5))

        return True
Example #5
    def upload_directory(self):
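        # Upload each pending file that still requires upload as a multipart/form-data request and process the response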

        # Iterate over a copy so entries can be removed from self.pending while handling responses
        res = True
        for file_data in list(self.pending.values()):
            if not file_data.upload:
                continue
            try:
                session = requests.Session()
                json_request = self.payload.create_upload_payload(file_data)

                Logger.log(LogLevel.DEBUG, json_request)
                upload_url = utils.gs.get_selector(self.server, utils.gs.UPLOAD)
                with open(file_data.file_path, 'rb') as f:
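                    # MultipartEncoder streams the open file handle, so large files are not read fully into memory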
                    form = MultipartEncoder({
                        "request": json_request,
                        "file": f,
                    })
                    headers = dict(self.headers)  # copy so the shared headers are not mutated
                    headers["Content-Type"] = form.content_type
                    resp = session.post(upload_url,
                                        headers=headers,
                                        data=form,
                                        cookies=self.cookies,
                                        verify=self.verify)

                    Logger.log(LogLevel.DEBUG, resp)
                    if not self.handle_response(resp):
                        raise Exception('Failed to handle upload response')

            except Exception as e:
                Logger.log(LogLevel.ERROR, 'Uploading Error', e)
                res = False
                continue
        return res
Example #6
    def __init__(self, scan_directory, file_path, file_name, api_key, server, reports_folder, tex_method, tex_folder,
                 features=DEFAULT_FEATURES,
                 reports=DEFAULT_REPORTS,
                 recursive=DEFAULT_RECURSIVE_EMULATION):
        """
        Setting the requested parameters and creating
        :param scan_directory: the requested directory
        :param file_path: the requested file path
        :param file_name: the requested file name
        :param api_key: API Key for the cloud service
        :param server: Check Point SandBlast Appliance ip address
        :param reports_folder: the folder which the reports will be save to
        :param tex_method: the method to be used with Thereat Extraction
        :param tex_directory: the folder which the Scrubbing attachments will be save to
        :param features: the requested features
        :param reports: type of reports
        :param recursive: find files in the requested directory recursively
        """
        if api_key:
            self.headers = {'Authorization': api_key}
        else:
            self.headers = {}
        self.reports_folder = reports_folder
        self.tex_folder = tex_folder
        if features:
            self.features = features
        else:
            self.features = DEFAULT_FEATURES
        self.payload = Payload(reports, tex_method)
        self.server = server
        self.verify = True
        # Containers tracking files by state: pending (keyed by MD5), finished, and errors
        self.pending = {}
        self.finished = []
        self.error = []

        try:
            if reports_folder and not os.path.exists(reports_folder):
                os.makedirs(reports_folder)
        except Exception as e:
            Logger.log(LogLevel.CRITICAL,
                       'failed to create the needed folders', e)

        max_files = DEFAULT_MAX_FILES
        Logger.log(LogLevel.INFO, 'Calculating hash of files ')
        if scan_directory:
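            # Walk the directory (recursively only if requested) and hash each file, up to DEFAULT_MAX_FILES files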
            for root, subdir_list, file_list in os.walk(u'%s' % scan_directory):
                for fn in file_list:
                    if max_files == 0:
                        Logger.log(LogLevel.INFO,
                                   'Max of %d files' % DEFAULT_MAX_FILES)
                        break
                    else:
                        max_files -= 1
                    if os.path.isfile(os.path.join(root, fn)):
                        file_data = FileData(fn.encode('utf-8'), os.path.join(root, fn), list(self.features))
                        file_data.compute_hashes()
                        self.pending[file_data.md5] = file_data
                if not recursive or max_files == 0:
                    break
Example #7
def main():
    parser = argparse.ArgumentParser(
        description='Threat Prevention API example')

    files_argument_group = parser.add_mutually_exclusive_group(
        required=not IS_ONLY_GENERATE_TOKEN)
    files_argument_group.add_argument('-D',
                                      '--directory',
                                      help='The scanning directory')
    files_argument_group.add_argument('-fp',
                                      '--file_path',
                                      help='Path to file')

    parser.add_argument('-fn',
                        '--file_name',
                        help='File Name, relevant when file path supplied')
    parser.add_argument(
        '-R',
        '--recursive',
        action='store_true',
        help=
        'Emulate the files in the directory recursively, relevant when scanning directory supplied'
    )

    server_argument_group = parser.add_mutually_exclusive_group(
        required=not IS_ONLY_GENERATE_TOKEN)
    server_argument_group.add_argument('-k', '--key', help='API key')
    server_argument_group.add_argument('-e',
                                       '--sandblast_appliance',
                                       help='Check Point SandBlast Appliance')
    server_argument_group.add_argument(
        '-ci',
        '--client_id',
        nargs=2,
        metavar=('CLIENT_ID', 'ACCESS_KEY'),
        help=
        'Client ID and Access key, used for JWT token authenticated requests')

    parser.add_argument(
        '-gt',
        '--generate_token',
        action='store_true',
        help='Only create the JWT token without sending a request')
    parser.add_argument('-d',
                        '--debug',
                        action='store_true',
                        help='Add debugging')

    blades_info = parser.add_argument_group('Blades info')
    blades_info.add_argument('-t',
                             '--te',
                             action='store_true',
                             help='Activate Threat Emulation')
    blades_info.add_argument(
        '--tex',
        action='store_true',
        help='Activate Threat Extraction (supported only with cloud)')
    blades_info.add_argument(
        '--tex_folder',
        help=
        'A folder to download the Scrubbing attachments (required when TEX is active)'
    )
    blades_info.add_argument(
        '-m',
        '--tex_method',
        choices=['convert', 'clean'],
        default='convert',
        help='Scrubbing method. Convert to PDF / CleanContent')

    reports_section = parser.add_argument_group('Reports info',
                                                'Download Reports')
    reports_section.add_argument(
        '-r',
        '--reports',
        help='A folder to download the reports to (required for cloud)',
        required=False)
    reports_section.add_argument(
        '-p',
        '--pdf',
        action='store_true',
        help='Download PDF reports',
    )
    reports_section.add_argument(
        '-x',
        '--xml',
        action='store_true',
        help='Download XML reports',
    )
    reports_section.add_argument(
        '-s',
        '--summary',
        action='store_true',
        help='Download summary reports',
    )
    args = parser.parse_args()

    Logger.level = LogLevel.DEBUG if args.debug else LogLevel.INFO

    # Asking the API to enable features and reports according
    # to what was required by the user.
    features = []
    reports = []
    server = ""
    key = ""
    client_id = ""
    access_key = ""
    file_path = ""
    file_name = ""
    directory = ""

    if args.te:
        features.append('te')
    if args.tex:
        features.append('extraction')

    if args.summary and args.pdf:
        parser.error(
            "Illegal request: PDF reports are not available in the new Threat Emulation "
            "reports format, so requesting PDF and summary reports simultaneously is not supported."
        )
        exit(-1)

    if args.xml:
        reports.append('xml')
    if args.pdf:
        reports.append('pdf')
    if args.summary:
        reports.append('summary')

    # Verify the user values
    if len(reports) and not args.reports:
        parser.error("Please supply a reports directory")
        exit(-1)

    if args.key:
        key = args.key
        if not args.reports:
            parser.error("API Key supplied, please supply a reports folder")
            exit(-1)

    elif args.client_id:
        client_id = args.client_id[0]
        access_key = args.client_id[1]
        if not args.generate_token and not args.reports:
            parser.error("API Token supplied, please supply a reports folder")
            exit(-1)

    elif args.sandblast_appliance:
        if args.tex:
            Logger.log(
                LogLevel.ERROR,
                'TEX is not supported with Check Point SandBlast Appliance')
            features.remove('extraction')
        server = args.sandblast_appliance

    if args.tex:
        if not args.tex_folder:
            parser.error("TEX is active, please supply a tex folder")
            exit(-1)
        if not os.path.isdir(args.tex_folder):
            Logger.log(LogLevel.ERROR, 'Invalid tex folder as input')
            exit(-1)

    if not args.generate_token:
        if args.directory:
            if not os.path.isdir(args.directory):
                Logger.log(LogLevel.ERROR,
                           'Invalid scanning directory in input')
                exit(-1)
            directory = args.directory
        else:
            file_path = args.file_path.encode('utf-8')
            if args.file_name:
                file_name = args.file_name.encode('utf-8')
            else:
                file_name = os.path.basename(file_path)
            if not os.path.isfile(args.file_path):
                Logger.log(LogLevel.ERROR,
                           'Invalid file path in input (%s)' % args.file_path)
                exit(-1)

    api = Run(directory, file_path, file_name, key, client_id, access_key,
              args.generate_token, server, args.reports, args.tex_method,
              args.tex_folder, features, reports, args.recursive)

    if not api.is_pending_files():
        Logger.log(LogLevel.INFO, 'The directory is empty')
        exit(0)

    if directory:
        Logger.log(
            LogLevel.INFO, 'Querying %d files from directory: %s' %
            (len(api.pending), args.directory))
    else:
        Logger.log(LogLevel.INFO, 'Querying file: %s ' % (file_path))

    api.query_directory(True)
    api.print_arrays_status()

    if api.is_pending_files():
        Logger.log(LogLevel.INFO, 'UPLOADING')
        api.upload_directory()
        api.print_arrays_status()

    max_tries = MAX_TRIES
    while api.is_pending_files() and max_tries > 0:
        time.sleep(WAITING_SEC)
        api.query_directory(False)
        api.print_arrays_status()
        max_tries -= 1

    api.print_arrays()

    ret = api.get_final_status()
    print("return {}".format(ret))

    exit(ret)
Example #8
def run_simulation(**kwargs):
    kp = KwargsParser(kwargs, DEFAULTS)
    folder = Path(kp.folder).expanduser()
    folder.mkdir(exist_ok=True, parents=True)

    file_str = f'L_{kp.L}_g_{kp.g}_chi_{kp.chi}_dt_{kp.dt}_quench_{kp.quench}'
    if kp.task_id:
        file_str += f'_{kp.task_id}'
    logger = Logger(folder.joinpath(file_str + '.log'), True)
    opt_logger = Logger(folder.joinpath(file_str + '.opt.log'), True)
    outfile = folder.joinpath(file_str + '.pkl')
    kp.log(logger)

    opt_opts = dict(display_fun=get_display_fun(opt_logger),
                    line_search_fn='strong_wolfe',
                    max_iter=kp.max_iter,
                    tolerance_grad=kp.tolerance_grad)
    cont_opts = dict(contraction_method='brute')

    model = TFIM(kp.g,
                 bc='obc',
                 lx=kp.L,
                 ly=kp.L,
                 dtype_hamiltonian=np.float64)
    evolver = TimeEvolution(kp.g,
                            kp.dt,
                            'obc',
                            real_time=True,
                            lx=kp.L,
                            ly=kp.L,
                            pepo_dtype=np.complex128)

    logger.log(f'Starting with groundstate of g={kp.g} TFIM')

    # Prepare groundstate

    gs = None
    gs_energy = None

    if kp.gs_file:
        logger.log('GS file specified, loading GS from file')
        try:
            with open(kp.gs_file, 'rb') as f:
                res = pickle.load(f)
            gs_tensors = res['gs_tensors']
            gs = Peps(gs_tensors, 'obc')
            gs_energy = res['gs_energy']

            assert np.allclose(kp.g, res['kwargs']['g'])
            assert gs.lx == kp.L
            assert gs.ly == kp.L
        except Exception as e:
            logger.log('Failed to load GS from file. Error: ' + str(e))

    if (gs is None) or (gs_energy is None):
        logger.log('No GS file specified, optimising gs...')
        gs, gs_energy = model.groundstate(kp.chi, (kp.L, kp.L), 'ps', 0.05,
                                          cont_opts, opt_opts)

        logger.log('Saving GS to ' +
                   str(folder.joinpath(file_str + '.gs.pkl')))
        results = dict(kwargs=kp.kwargs(), gs=gs, gs_energy=gs_energy, g=kp.g)
        with open(folder.joinpath(file_str + '.gs.pkl'), 'wb') as f:
            pickle.dump(results, f)

    # Prepare quench

    if kp.quench == 'X':  # <Sx(r,t) Sx(center,0)>
        quench_operator = sx
        measure_operator = sx
    elif kp.quench == 'Y':  # <Sy(r,t) Sy(center,0)>
        quench_operator = sy
        measure_operator = sy
    elif kp.quench == 'Z':  # <Sz(r,t) Sz(center,0)>
        quench_operator = sz
        measure_operator = sz
    elif kp.quench == '+':  # <S+(r,t) S-(center,0)>
        quench_operator = sm
        measure_operator = sp
    else:
        raise ValueError(f'Illegal quench code {kp.quench}')

    logger.log('Quench: Applying quench operator to center site')
    quenched = SingleSiteOperator(quench_operator, kp.L // 2,
                                  kp.L // 2).apply_to_peps(gs)

    # Time evolution

    x_snapshot_data = onp.zeros([kp.n_steps + 1, kp.L, kp.L])
    y_snapshot_data = onp.zeros([kp.n_steps + 1, kp.L, kp.L])
    z_snapshot_data = onp.zeros([kp.n_steps + 1, kp.L, kp.L])
    correlator_data = onp.zeros([kp.n_steps + 1, kp.L, kp.L],
                                dtype=onp.complex128)
    t_data = onp.zeros([kp.n_steps + 1])

    state = quenched
    opt_opts['dtype'] = np.complex128
    opt_opts['max_grad_evals_ls'] = 100
    for n in range(kp.n_steps):
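        # At each step: measure observables at t = n*dt, checkpoint the results, then evolve the state by dt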
        logger.log('Computing Observables')

        t = n * kp.dt
        x_snapshot_data[n, :, :] = x_snapshot(state, cont_opts)
        y_snapshot_data[n, :, :] = y_snapshot(state, cont_opts)
        z_snapshot_data[n, :, :] = z_snapshot(state, cont_opts)
        correlator_data[n, :, :] = correlator_timeslice(
            gs, state, measure_operator, gs_energy, t, **cont_opts)
        t_data[n] = t

        logger.log(f'Evolving to t={(n + 1) * kp.dt}')
        state = evolver.evolve(state,
                               contraction_options=cont_opts,
                               optimisation_options=opt_opts,
                               random_dev=None,
                               initial=kp.initial)

        # Save intermediate results (overwritten every step) in case the process dies before finishing
        results = dict(kwargs=kp.kwargs(),
                       quench=kp.quench,
                       x_snapshot=x_snapshot_data,
                       y_snapshot=y_snapshot_data,
                       z_snapshot=z_snapshot_data,
                       correlator=correlator_data,
                       t=t_data,
                       state_tensors=state.get_tensors())
        with open(outfile, 'wb') as f:
            pickle.dump(results, f)

        if kp.save_all_peps:
            results = dict(kwargs=kp.kwargs(),
                           t=t,
                           state_tensors=state.get_tensors())
            with open(folder.joinpath(file_str + f'_state_t_{t}.pkl'),
                      'wb') as f:
                pickle.dump(results, f)

    logger.log('Computing Observables')
    t = kp.n_steps * kp.dt
    x_snapshot_data[kp.n_steps, :, :] = x_snapshot(state, cont_opts)
    y_snapshot_data[kp.n_steps, :, :] = y_snapshot(state, cont_opts)
    z_snapshot_data[kp.n_steps, :, :] = z_snapshot(state, cont_opts)
    correlator_data[kp.n_steps, :, :] = correlator_timeslice(
        gs, state, measure_operator, gs_energy, t, **cont_opts)
    t_data[kp.n_steps] = t

    # save results
    logger.log(f'saving results to {outfile}')
    results = dict(kwargs=kp.kwargs(),
                   quench=kp.quench,
                   x_snapshot=x_snapshot_data,
                   y_snapshot=y_snapshot_data,
                   z_snapshot=z_snapshot_data,
                   correlator=correlator_data,
                   t=t_data,
                   state_tensors=state.get_tensors())
    with open(outfile, 'wb') as f:
        pickle.dump(results, f)

    if kp.save_all_peps:
        results = dict(kwargs=kp.kwargs(),
                       t=t,
                       state_tensors=state.get_tensors())
        with open(folder.joinpath(file_str + f'_state_t_{t}.pkl'), 'wb') as f:
            pickle.dump(results, f)
Example #9
                val_events, val_bases = dataset.fetch_test_batch()
                val_events = Variable(val_events.transpose(0, 1).contiguous())
                val_bases = Variable(
                    val_bases.transpose(0, 1).contiguous().long())
                if cuda:
                    val_events = val_events.cuda()
                    val_bases = val_bases.cuda()
                val_nll = crf.neg_log_likelihood(val_events, val_bases)
                if print_viterbi:
                    vscore, vpaths = crf(val_events)
                    print("Viterbi score:")
                    print(vscore)
                    print("Viterbi paths:")
                    print(vpaths)
                logger.log(step, tr_nll.data[0], val_nll.data[0],
                           tr_nll.data[0] / batch_size,
                           val_nll.data[0] / batch_size)

            # serialize model occasionally:
            if step % save_every == 0: logger.save(step, crf)

            step += 1
            if step > max_iters: raise StopIteration

        del dataset
#--- handle keyboard interrupts:
except KeyboardInterrupt:
    del dataset
    logger.close()
    print("-" * 80)
    print("Halted training; reached {} training iterations.".format(step))
Example #10
    @staticmethod
    def print_status(array, text):
        array_size = len(array)
        if array_size > 0:
            Logger.log(LogLevel.INFO, '%s: %s files' % (text, str(array_size)))
Example #11
    def print_arrays_status(self):
        Logger.log(LogLevel.INFO, 'PROGRESS:')
        self.print_status(self.pending, 'Pending')
        self.print_status(self.error, 'Error')
        self.print_status(self.finished, 'Finished')
Example #12
def main(config, exp_dir, checkpoint=None):
    torch.manual_seed(config["random_seed"])
    np.random.seed(config["random_seed"])
    random.seed(config["random_seed"])

    logger = Logger(exp_dir)

    device = torch.device("cuda" if config["use_gpu"] else "cpu")

    train_loader, val_loader = get_data_loaders(config, device)

    model = get_model(config["model_name"], **config["model_args"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config["learning_rate"],
                                 weight_decay=config["weight_decay"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")

    if "load_encoder" in config:
        encoder_model, _ = load_checkpoint(config["load_encoder"], device,
                                           get_model)
        model.encoder = encoder_model.encoder

    if checkpoint:
        logger.log("Resume training..")
        metrics = load_metrics(exp_dir)
        best_val_loss = checkpoint["best_val_loss"]
        i_episode = checkpoint["epoch"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
    else:
        i_episode = 0
        best_val_loss = float("inf")
        metrics = {
            "between_eval_time": AverageMeter(),
            "data_time": AverageMeter(),
            "batch_time": AverageMeter(),
            "train_losses": AverageMeter(),
            "train_accs": AverageMeter(),
            "val_time": AverageMeter(),
            "val_batch_time": AverageMeter(),
            "val_data_time": AverageMeter(),
            "val_losses": AverageMeter(),
            "val_accs": AverageMeter()
        }

    keep_training = True
    end = time.time()
    between_eval_end = time.time()
    while keep_training:
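        # Train until num_episodes is reached; run validation and checkpoint every eval_steps episodes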
        for batch in train_loader:
            metrics["data_time"].update(time.time() - end)
            batch["slide"] = batch["slide"].to(device)

            model.train()
            optimizer.zero_grad()

            scores = compute_loss(config, model, batch, device)
            loss, acc = scores["loss"], scores["accuracy"]

            metrics["train_losses"].update(loss.item())
            metrics["train_accs"].update(acc)

            loss.backward()
            optimizer.step()
            metrics["batch_time"].update(time.time() - end)
            end = time.time()

            del acc
            del loss
            del batch
            if i_episode % config["eval_steps"] == 0:
                val_loss, val_acc = test(config, model, device, val_loader,
                                         metrics)
                scheduler.step(val_loss)

                metrics["between_eval_time"].update(time.time() -
                                                    between_eval_end)

                # Our optimizer has only one parameter group so the first
                # element of our list is our learning rate.
                lr = optimizer.param_groups[0]['lr']
                logger.log(
                    "Episode {0}\n"
                    "Time {metrics[between_eval_time].val:.3f} (data {metrics[data_time].val:.3f} batch {metrics[batch_time].val:.3f}) "
                    "Train loss {metrics[train_losses].val:.4e} ({metrics[train_losses].avg:.4e}) "
                    "Train acc {metrics[train_accs].val:.4f} ({metrics[train_accs].avg:.4f}) "
                    "Learning rate {lr:.2e}\n"
                    "Val time {metrics[val_time].val:.3f} (data {metrics[val_data_time].avg:.3f} batch {metrics[val_batch_time].avg:.3f}) "
                    "Val loss {metrics[val_losses].val:.4e} ({metrics[val_losses].avg:.4e}) "
                    "Val acc {metrics[val_accs].val:.4f} ({metrics[val_accs].avg:.4f}) "
                    .format(i_episode, lr=lr, metrics=metrics))

                save_metrics(metrics, exp_dir)

                is_best = val_loss < best_val_loss
                best_val_loss = val_loss if is_best else best_val_loss
                save_checkpoint(
                    {
                        "epoch": i_episode,
                        "model_name": config["model_name"],
                        "model_args": config["model_args"],
                        "state_dict": model.state_dict(),
                        "best_val_loss": best_val_loss,
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict()
                    },
                    is_best,
                    path=exp_dir)
                end = time.time()
                between_eval_end = time.time()
                del val_loss
                del val_acc

            if i_episode >= config["num_episodes"]:
                keep_training = False
                break

            i_episode += 1
Example #13
def run_simulation(**kwargs):
    kp = KwargsParser(kwargs, DEFAULTS)
    folder = Path(kp.folder).expanduser()
    folder.mkdir(exist_ok=True, parents=True)
    logger = Logger(
        folder.joinpath(f'chi_{kp.chi}_Dopt_{kp.D_opt}_{kp.init}.log'), True)
    kp.log(logger)

    gs = None

    for g in kp.g_list:
        logger.lineskip()
        logger.log(f'g={g}')

        file_str = f'chi_{kp.chi}_Dopt_{kp.D_opt}_g_{g}_{kp.init}'
        opt_logger = Logger(folder.joinpath(file_str + '.opt.log'), True)
        outfile = folder.joinpath(file_str + '.pkl')

        opt_opts = dict(display_fun=get_display_fun(opt_logger),
                        line_search_fn='strong_wolfe',
                        max_iter=kp.max_iter,
                        dtype=np.float64)
        cont_opts = dict(chi_ctm=kp.D_opt)

        model = TFIM(g, bc='infinite', dtype_hamiltonian=np.float64)

        # initial state
        if kp.init == 'load':
            # find the file with the closest g
            g_closest = np.inf
            file_closest = None
            for f in os.listdir(folder):
                if f[-3:] != 'pkl':
                    continue
                start = f.rfind('_g_')
                if start == -1:
                    continue
                ends = [
                    f.rfind(f'_{init}') for init in ['ps', 'z+', 'x+', 'load']
                ]
                end = [e for e in ends if e != -1][0]
                _g = float(f[start + 3:end])
                if np.abs(g - _g) < np.abs(g -
                                           g_closest):  # closer than previous
                    g_closest = _g
                    file_closest = f

            # noinspection PyBroadException
            try:
                with open(folder.joinpath(file_closest), 'rb') as f:
                    results = pickle.load(f)
                init = results['gs']
                initial_noise = 0.
                print(f'loaded initial guess from {file_closest}')
            except Exception:
                if gs:  # if gs is available from previous loop iteration, use that
                    # failed to load even though this is not the first loop iteration
                    print('warning: loading from file failed', file=sys.stderr)
                    init = gs
                    initial_noise = 0.001
                else:  # if nothing found, use product state
                    init = 'ps'
                    initial_noise = 0.05
        elif kp.init == 'last':
            if gs:
                init = gs
                initial_noise = 0.
            else:
                init = 'ps'
                initial_noise = 0.05
        else:
            init = kp.init
            initial_noise = 0.05

        gs, gs_energy = model.groundstate(chi=kp.chi,
                                          initial_state=init,
                                          initial_noise=initial_noise,
                                          contraction_options=cont_opts,
                                          optimisation_options=opt_opts)

        en_list = []
        mag_list = []
        D_obs_list = []

        for D_obs in list(range(kp.D_opt))[10::10] + [kp.D_opt]:
            logger.log(f'D_obs={D_obs}')
            D_obs_list.append(D_obs)
            cont_opts['chi_ctm'] = D_obs
            en_list.append(model.energy(gs, **cont_opts))
            mag_list.append(
                LocalOperator(sx, np.array([[0]]),
                              hermitian=True).expval(gs, **cont_opts))

        print(f'saving results to {outfile}')

        results = dict(kwargs=kwargs,
                       g=g,
                       optimal_energy=gs_energy,
                       D_obs_list=D_obs_list,
                       en_list=en_list,
                       mag_list=mag_list,
                       logfile=str(opt_logger.logfile),
                       gs_tensors=gs.get_tensors())
        with open(outfile, 'wb') as f:
            pickle.dump(results, f)

        for D_obs in list(range(kp.D_opt,
                                2 * kp.D_opt))[10::10] + [2 * kp.D_opt]:
            logger.log(f'D_obs={D_obs}')
            D_obs_list.append(D_obs)
            cont_opts['chi_ctm'] = D_obs
            en_list.append(model.energy(gs, **cont_opts))
            mag_list.append(
                LocalOperator(sx, np.array([[0]]),
                              hermitian=True).expval(gs, **cont_opts))

        results = dict(kwargs=kwargs,
                       g=g,
                       optimal_energy=gs_energy,
                       D_obs_list=D_obs_list,
                       en_list=en_list,
                       mag_list=mag_list,
                       logfile=str(opt_logger.logfile),
                       gs_tensors=gs.get_tensors())

        with open(outfile, 'wb') as f:
            pickle.dump(results, f)