def __init__(self, config: str, pretrained: bool = True):
        """Assemble the fused model from a text backbone and a profiling net.

        Two sub-networks are built from the parsed configuration, their
        optional weights are loaded, and their classification heads are
        removed so that their features can be combined and fed to a new
        fully connected head defined here.

        Args:
            config: Configuration handed to ``parse_config``. The result must
                expose ``BACKBONE.NET`` and ``PROFILING.NET`` and may expose
                ``BACKBONE.WEIGHT`` / ``PROFILING.WEIGHT``.
            pretrained: When true, ``_freeze`` is applied to both
                sub-networks (presumably disabling their gradient updates;
                see ``_freeze``).
        """
        super().__init__()
        self.cfg = parse_config(config)

        # Text backbone: load weights if configured, then neutralise its
        # final linear layer and softmax so it emits features.
        self.backbone = TextLSTM(BASE_DIR / self.cfg.BACKBONE.NET).cuda()
        if getattr(self.cfg.BACKBONE, 'WEIGHT', None):
            self.backbone.load_state_dict(
                torch.load(BASE_DIR / self.cfg.BACKBONE.WEIGHT))
        self.backbone.fc = nn.Identity()
        self.backbone.sfm = nn.Identity()

        # Profiling net: the weight check must look at the PROFILING section.
        # Previously it tested BACKBONE.WEIGHT, so profiling weights were
        # skipped (or a missing PROFILING.WEIGHT crashed) depending on an
        # unrelated setting.
        profiling = EliteNet(BASE_DIR / self.cfg.PROFILING.NET).cuda()
        if getattr(self.cfg.PROFILING, 'WEIGHT', None):
            profiling.load_state_dict(
                torch.load(BASE_DIR / self.cfg.PROFILING.WEIGHT))
        # Keep every child module except the last one (the output layer).
        self.profiling = nn.Sequential(*list(profiling.children())[:-1])

        # Width of the concatenated feature vector fed to the new head.
        bottle_neck = self.backbone.cfg.HIDDEN_SIZE + profiling.cfg.FC4

        self.fc1 = nn.Linear(bottle_neck, 256, bias=True)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 512, bias=True)
        self.bn2 = nn.BatchNorm1d(512)
        self.out = nn.Linear(512, 2, bias=True)

        # freeze if pretrained flag is set
        if pretrained:
            _freeze(self.backbone)
            _freeze(self.profiling)
Пример #2
0
def get_config():
	"""Fetch the iView "config" and store the parsed result globally.

	The config, among other things, names an always-metered "fallback"
	RTMP server and points to many of iView's other XML files.

	The parsed value is written to the module-level ``iview_config``;
	nothing is returned.
	"""
	global iview_config

	raw_config = maybe_fetch(config.config_url)
	iview_config = parser.parse_config(raw_config)
Пример #3
0
    def __init__(self, config):
        """Initialize a design instance from a configuration object.

        Factors missing from the parsed configuration default to an empty
        list.

        Args:
            config (object): JSON configuration of the experimental design
        """
        factors_by_type = parse_config(config)

        # The parsed configuration is keyed by the FactorType member names.
        between_key = FactorType.between_subject.name
        within_key = FactorType.within_subject.name

        self.between_subject_factors = factors_by_type.get(between_key, [])
        self.within_subject_factors = factors_by_type.get(within_key, [])
Пример #4
0
    def __init__(self, config):
        """Create the layers of the fully connected network.

        Layer widths (``INPUT``, ``FC1`` .. ``FC4``, ``OUTPUT``) and dropout
        probabilities (``DROP1``, ``DROP2``) are read from the parsed config.

        Args:
            config: Value handed to ``parse_config``; the parsed result is
                kept on ``self.cfg``.
        """
        super().__init__()
        cfg = parse_config(config)
        self.cfg = cfg

        # Submodules are registered in the same order as before, because
        # ``children()`` yields them in registration order.
        self.input = nn.Linear(in_features=cfg.INPUT,
                               out_features=cfg.FC1,
                               bias=True)

        self.fc1 = nn.Linear(in_features=cfg.FC1,
                             out_features=cfg.FC2,
                             bias=True)
        self.bn_1 = nn.BatchNorm1d(num_features=cfg.FC2)
        self.drop_1 = nn.Dropout(p=cfg.DROP1)

        self.fc2 = nn.Linear(in_features=cfg.FC2,
                             out_features=cfg.FC3,
                             bias=True)
        self.bn_2 = nn.BatchNorm1d(num_features=cfg.FC3)
        self.drop_2 = nn.Dropout(p=cfg.DROP2)

        self.fc3 = nn.Linear(in_features=cfg.FC3,
                             out_features=cfg.FC4,
                             bias=True)
        self.out = nn.Linear(in_features=cfg.FC4,
                             out_features=cfg.OUTPUT,
                             bias=True)
Пример #5
0
    def __init__(self, config):
        """Create the embedding, LSTM and output layers of the text model.

        Args:
            config: Value handed to ``parse_config``. The parsed result must
                provide ``OUTPUT_SIZE``, ``HIDDEN_SIZE``, ``EMBEDDING_LENGTH``
                and ``EMBEDDING_DIR`` (a path, relative to ``BASE_DIR``, of a
                NumPy file holding the embedding matrix).
        """
        super().__init__()
        cfg = parse_config(config)
        self.cfg = cfg

        self.output_size = cfg.OUTPUT_SIZE
        self.hidden_size = cfg.HIDDEN_SIZE
        self.embedding_length = cfg.EMBEDDING_LENGTH

        # Embedding table initialised from the stored matrix, cast to float32.
        embedding_matrix = np.load(BASE_DIR / cfg.EMBEDDING_DIR)
        embedding_weights = torch.from_numpy(embedding_matrix).float()
        self.word_embeddings = nn.Embedding.from_pretrained(embedding_weights)

        self.lstm = nn.LSTM(input_size=cfg.EMBEDDING_LENGTH,
                            hidden_size=self.hidden_size,
                            bidirectional=True)
        self.fc = nn.Linear(in_features=self.hidden_size,
                            out_features=self.output_size)
        # NOTE(review): Softmax is built without an explicit ``dim``, so the
        # axis is chosen implicitly at call time - confirm this is intended.
        self.sfm = nn.Softmax()
Пример #6
0
 def test_parse_reads_file(self):
     """Write a one-mod ``config.json`` and check what ``parse_config`` yields.

     ``parse_config()`` is called without arguments, so it is presumably
     expected to pick up the file written to the working directory. The
     file is removed again once the test finishes.
     """
     github_mod = {
         "source": "GITHUB",
         "owner": "snallapa",
         "repo": "scentfindermod"
     }
     fixture = {
         "mc_version": "1.16.5",
         "mods_dir": "mods",
         "mods": [github_mod]
     }
     with open('config.json', 'w') as outfile:
         json.dump(fixture, outfile)

     # Registered only after the file exists, as in the original flow.
     self.addCleanup(os.remove, "config.json")

     parsed_sources = parse_config()
     self.assertEqual(1, len(parsed_sources))
     first_source = parsed_sources[0]
     self.assertEqual("GithubSource", first_source.__class__.__name__)
Пример #7
0
def get_config():
	"""Fetch the iView "config" and return its parsed form.

	The config, among other things, names an always-metered "fallback"
	RTMP server and points to many of iView's other XML files.
	"""
	raw_config = fetch_url(config.config_url)
	return parser.parse_config(raw_config)
Пример #8
0
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from model import Net
from parser import parse_args, parse_config
from dataset import Dataset

if __name__ == "__main__":
    # `shutil` is used below but is not among the imports visible at the top
    # of this file, so it is imported here to keep the script runnable.
    import shutil

    # Parse the command line, then let parse_config derive the final settings.
    args = parse_args()
    args = parse_config(args)

    # Make sure the output directories exist. The checkpoints directory is
    # created without `parents=True`, so its parent must already exist.
    args.model_dir.mkdir(parents=True, exist_ok=True)
    args.checkpoints_dir.mkdir(exist_ok=True)

    # Keep a copy of the configuration next to the model; a resumed run
    # (non-zero resume_checkpoint) gets its own file name so that the
    # settings of the initial run are not overwritten.
    if args.resume_checkpoint == 0:
        settings_name = 'experiment_settings.cfg'
    else:
        settings_name = 'experiment_settings_resume.cfg'
    shutil.copy(args.cfg, args.model_dir / settings_name)

    # Start from an empty log file.
    if args.log_file.exists():
        args.log_file.unlink()
    # NOTE(review): `logger` is not imported in the visible part of this
    # file; the format string and the backtrace/diagnose keywords look like
    # loguru's `logger.add` - confirm and add the matching import.
    logger.add(args.log_file,
               format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
               backtrace=False,
               diagnose=False)
Пример #9
0
  02:       473         491         (200 iterations)
  03:       641         673         (300 iterations)
  04:       1001        1092        (1100 after 100 iterations)
  05:       749         801         400 iterations
  06:       876         113         200 iterations
  07:       885
  08:       4437        5362

  801 with 5 10 20 population profile on problem 05. 500 iterations.
  839 with 5 population profile on problem 05. 500 iterations.
  002 with 5 population profile on problem 05. 500 iterations. local optimum..

'''

# Load the problem instance from `filename` and the default configuration.
depots, customers, durations, n_paths_per_depot = loader.load_dataset(filename)
conf = configparser.parse_config('configs/default.conf')

# Reuse the axes of the current figure when a figure is already open
# (the unpacking requires it to have exactly two axes); otherwise create a
# new figure with a 1x2 grid of subplots.
if len(plt.get_fignums()) > 0:
    ax0, ax1 = plt.gcf().get_axes()
else:
    _, (ax0, ax1) = plt.subplots(1, 2)

# Build the model and hand it, together with `solution_file`, to the
# visualisation helper (presumably draws the reference solution - confirm
# against utils.visualize_solution).
model = MDVRPModel(customers, depots, n_paths_per_depot, conf)
optimal_solution = utils.visualize_solution(model, solution_file)

# Presumably runs 3 evolution steps - confirm against MDVRPModel.evolve.
model.evolve(3)
one = model.population[0]  # debug

# Score every individual and keep the one with the lowest fitness score
# (np.argmin returns the first index on ties).
L = [each.fitness_score() for each in model.population]
best = model.population[np.argmin(L)]