Example #1
def get_train_loader(batch_size,
                     mu,
                     n_iters_per_epoch,
                     L,
                     root='dataset',
                     seed=0):
    data_x, label_x, data_u, label_u = load_data_train(L=L,
                                                       dspth=root,
                                                       seed=seed)

    ds_x = Cifar10(data=data_x, labels=label_x, is_train=True)
    sampler_x = RandomSampler(ds_x,
                              replacement=True,
                              num_samples=n_iters_per_epoch * batch_size)
    batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True)
    dl_x = torch.utils.data.DataLoader(ds_x,
                                       batch_sampler=batch_sampler_x,
                                       num_workers=1,
                                       pin_memory=True)
    ds_u = Cifar10(data=data_u, labels=label_u, is_train=True)
    sampler_u = RandomSampler(ds_u,
                              replacement=True,
                              num_samples=mu * n_iters_per_epoch * batch_size)
    batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True)
    dl_u = torch.utils.data.DataLoader(ds_u,
                                       batch_sampler=batch_sampler_u,
                                       num_workers=2,
                                       pin_memory=True)
    return dl_x, dl_u
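Usage note: because both samplers draw a fixed number of indices with replacement, dl_x and dl_u yield exactly n_iters_per_epoch batches each and can be zipped in a semi-supervised training loop. A minimal sketch, assuming the function above (the argument values are illustrative and the per-batch structure depends on what Cifar10.__getitem__ returns):

dl_x, dl_u = get_train_loader(batch_size=64, mu=7, n_iters_per_epoch=1024, L=250)

for it, (batch_x, batch_u) in enumerate(zip(dl_x, dl_u)):
    # batch_x carries batch_size labeled samples, batch_u carries batch_size * mu
    # unlabeled samples; the zip never truncates because both loaders produce the
    # same number of batches per epoch.
    pass  # forward pass and loss computation would go here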
Example #2
File: cifar.py Project: MLDL/FROST
def get_train_loader(batch_size, mu, mu_c, n_iters_per_epoch, L, root='dataset', seed=0, name=None):
    if name is None:
        name = "dataset/seeds/size" + str(L) + "seed" + str(seed) + ".npy"
    data_x, label_x, data_u, label_u, data_all, label_all = load_data_train(L=L, dspth=root, seed=seed, name=name)
    
    ds_x = Cifar10(
        data=data_x,
        labels=label_x,
        is_train=True
    )
    sampler_x = RandomSampler(ds_x, replacement=True, num_samples=n_iters_per_epoch * batch_size)
    batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True)
    dl_x = torch.utils.data.DataLoader(
        ds_x,
        batch_sampler=batch_sampler_x,
        num_workers=1,
        pin_memory=True
    )
    
    ds_u = Cifar10(
        data=data_u,
        labels=label_u,
        is_train=True
    )
    sampler_u = RandomSampler(ds_u, replacement=True, num_samples=mu * n_iters_per_epoch * batch_size)
    batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True)
    dl_u = torch.utils.data.DataLoader(
        ds_u,
        batch_sampler=batch_sampler_u,
        num_workers=2,
        pin_memory=True
    )
    
    ds_all = Cifar10(
        data=data_all,
        labels=label_all,
        is_train=True
    )
    #sampler_all = RandomSampler(ds_all, replacement=True, num_samples= mu_c * n_iters_per_epoch * batch_size)
    sampler_all = SequentialSampler(ds_all)
    batch_sampler_all = BatchSampler(sampler_all, batch_size * mu_c, drop_last=True)
    dl_all = torch.utils.data.DataLoader(
        ds_all,
        batch_sampler=batch_sampler_all,
        num_workers=2,
        pin_memory=True
    )
    return dl_x, dl_u, dl_all
Example #3
def main(args):
    # Set logging
    if not os.path.exists("./log"):
        os.makedirs("./log")

    log = set_log(args)
    tb_writer = SummaryWriter('./log/tb_{0}'.format(args.log_name))

    # Set seed
    set_seed(args.seed, cudnn=args.make_deterministic)

    # Set sampler
    sampler = BatchSampler(args, log)

    # Set policy
    policy = CaviaMLPPolicy(
        input_size=int(np.prod(sampler.observation_space.shape)),
        output_size=int(np.prod(sampler.action_space.shape)),
        hidden_sizes=(args.hidden_size, ) * args.num_layers,
        num_context_params=args.num_context_params,
        device=args.device)

    # Initialise baseline
    baseline = LinearFeatureBaseline(
        int(np.prod(sampler.observation_space.shape)))

    # Initialise meta-learner
    metalearner = MetaLearner(sampler, policy, baseline, args, tb_writer)

    # Begin training
    train(sampler, metalearner, args, log, tb_writer)
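set_seed is a project-specific helper that is not shown here. A minimal sketch of what such a helper typically does, assuming the cudnn flag maps to args.make_deterministic (all names below are assumptions, not the project's actual implementation):

import random

import numpy as np
import torch


def set_seed(seed, cudnn=False):
    # Hypothetical seeding helper: fixes the Python, NumPy and PyTorch RNGs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if cudnn:
        # Trades speed for reproducible convolution algorithm selection.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False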
Example #4
    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 sampler=None,
                 batch_sampler=None,
                 num_workers=0,
                 collate_fn=default_collate,
                 pin_memory=False,
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
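This constructor makes explicit why batch_sampler is mutually exclusive with batch_size, shuffle, sampler and drop_last: passing a batch_sampler is equivalent to letting the loader build one from the other arguments. A short equivalence sketch using the standard torch.utils.data classes:

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

ds = TensorDataset(torch.arange(10).float())

# These two loaders batch the same way: shuffle=True is turned into a
# RandomSampler wrapped in a BatchSampler, exactly as in __init__ above.
dl_a = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
dl_b = DataLoader(ds,
                  batch_sampler=BatchSampler(RandomSampler(ds),
                                             batch_size=4,
                                             drop_last=True))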
Example #5
def dataloader(image_datasets, P, K, val_batch_size):
    """ 
    Loads the dataloaders to iterate on for each of the training and validation datasets.
    Args:
        image_datasets (dictionary): Dictionary accomodating the two datasets (objects).
        batch_size (integer): Number of images to yield per batch for the validation set.
        P, K (integer): Number of identities and samples respectively.
    Returns:
        A dictionary of two keys ('train', 'val') accommodating the dataloaders.
    """
    dataloaders_dict = {x: None for x in ['train', 'val']}
    print("\nInitializing Sampler and dataloaders_dict...")
    sampler = BatchSampler(image_datasets['train'], P, K)
    dataloaders_dict['val'] = torch.utils.data.DataLoader(
        image_datasets['val'],
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=4)
    dataloaders_dict['train'] = torch.utils.data.DataLoader(
        image_datasets['train'],
        batch_sampler=sampler,
        num_workers=4,
        pin_memory=True)
    print("Sampler and dataloaders_dict are Loaded. \n")
    return (dataloaders_dict)
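The BatchSampler used here takes (dataset, P, K), so it is a project-specific identity-balanced sampler rather than torch.utils.data.BatchSampler; its implementation is not shown in this example. A minimal sketch of such a P×K sampler, assuming the training dataset exposes one identity label per sample (the targets attribute and the class body are hypothetical):

import random
from collections import defaultdict

from torch.utils.data import Sampler


class BatchSampler(Sampler):
    # Hypothetical P×K sampler: each batch holds K samples for each of P identities.

    def __init__(self, dataset, P, K):
        self.P, self.K = P, K
        self.index_by_id = defaultdict(list)
        for idx, target in enumerate(dataset.targets):  # assumed label attribute
            self.index_by_id[target].append(idx)
        self.ids = list(self.index_by_id)

    def __iter__(self):
        for _ in range(len(self)):
            batch = []
            for pid in random.sample(self.ids, self.P):
                batch += random.choices(self.index_by_id[pid], k=self.K)
            yield batch

    def __len__(self):
        return sum(len(v) for v in self.index_by_id.values()) // (self.P * self.K)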
Example #6
    def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False):
        super(BaseDataGenerator, self).__init__()
        self.dataset = dataset
        self.index_sampler = BatchSampler(dataset,
                                          batch_size=batch_size,
                                          shuffle=shuffle,
                                          drop_last=drop_last)
        self._sampler_iter = iter(self.index_sampler)
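Unlike torch.utils.data.BatchSampler, which wraps a sampler, this BatchSampler receives the dataset itself together with batch_size, shuffle and drop_last, so it is presumably a project-level index batcher. A minimal sketch under that assumption (the class body is hypothetical):

import random


class BatchSampler:
    # Hypothetical index batcher: yields lists of dataset indices of length batch_size.

    def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False):
        self.size = len(dataset)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        indices = list(range(self.size))
        if self.shuffle:
            random.shuffle(indices)
        for start in range(0, self.size, self.batch_size):
            batch = indices[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                break
            yield batch

    def __len__(self):
        if self.drop_last:
            return self.size // self.batch_size
        return (self.size + self.batch_size - 1) // self.batch_size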
Example #7
def _train(dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, dic_path):
    '''
        Perform meta-testing for MAML, Metalight, Random, and Pretrained 

        Arguments:
            dic_exp_conf:           dict,   configuration of this experiment
            dic_agent_conf:         dict,   configuration of agent
            dic_traffic_env_conf:   dict,   configuration of traffic environment
            dic_path:               dict,   path of source files and output files
    '''

    random.seed(dic_agent_conf['SEED'])
    np.random.seed(dic_agent_conf['SEED'])
    tf.set_random_seed(dic_agent_conf['SEED'])

    sampler = BatchSampler(dic_exp_conf=dic_exp_conf,
                           dic_agent_conf=dic_agent_conf,
                           dic_traffic_env_conf=dic_traffic_env_conf,
                           dic_path=dic_path,
                           batch_size=args.fast_batch_size,
                           num_workers=args.num_workers)

    policy = config.DIC_AGENTS[args.algorithm](
        dic_agent_conf=dic_agent_conf,
        dic_traffic_env_conf=dic_traffic_env_conf,
        dic_path=dic_path)

    metalearner = MetaLearner(sampler,
                              policy,
                              dic_agent_conf=dic_agent_conf,
                              dic_traffic_env_conf=dic_traffic_env_conf,
                              dic_path=dic_path)

    if dic_agent_conf['PRE_TRAIN']:
        if not dic_agent_conf['PRE_TRAIN_MODEL_NAME'] == 'random':
            params = pickle.load(
                open(
                    os.path.join(
                        'model', 'initial', "common",
                        dic_agent_conf['PRE_TRAIN_MODEL_NAME'] + '.pkl'),
                    'rb'))
            metalearner.meta_params = params
            metalearner.meta_target_params = params

    tasks = dic_exp_conf['TRAFFIC_IN_TASKS']

    episodes = None
    for batch_id in range(dic_exp_conf['NUM_ROUNDS']):
        tasks = [dic_exp_conf['TRAFFIC_FILE']]
        if dic_agent_conf['MULTI_EPISODES']:
            episodes = metalearner.sample_meta_test(tasks[0], batch_id,
                                                    episodes)
        else:
            episodes = metalearner.sample_meta_test(tasks[0], batch_id)
Example #8
def dataloader(image_datasets, P, K, val_batch_size):
    """
    Builds the dataloaders to iterate over for the training and validation datasets.
    Args:
        image_datasets (dictionary): Dictionary containing the two datasets (objects).
        P, K (integer): Number of identities and samples per identity, respectively.
        val_batch_size (integer): Number of images to yield per batch for the validation set.
    Returns:
        A dictionary of two keys ('train', 'val') containing the dataloaders.
    """
    dataloaders = {}
    train_sampler = BatchSampler(image_datasets['train'], P, K)
    dataloaders['train'] = torch.utils.data.DataLoader(image_datasets['train'],
                                                       batch_sampler=train_sampler,
                                                       num_workers=4)
    dataloaders['val'] = torch.utils.data.DataLoader(image_datasets['val'],
                                                     batch_size=val_batch_size,
                                                     num_workers=4)
    return dataloaders
Example #9
def _train(dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, dic_path):

    random.seed(dic_agent_conf['SEED'])
    np.random.seed(dic_agent_conf['SEED'])
    tf.set_random_seed(dic_agent_conf['SEED'])

    sampler = BatchSampler(dic_exp_conf=dic_exp_conf,
                           dic_agent_conf=dic_agent_conf,
                           dic_traffic_env_conf=dic_traffic_env_conf,
                           dic_path=dic_path,
                           batch_size=args.fast_batch_size,
                           num_workers=args.num_workers)

    policy = config.DIC_AGENTS[args.algorithm](
        dic_agent_conf=dic_agent_conf,
        dic_traffic_env_conf=dic_traffic_env_conf,
        dic_path=dic_path
    )

    sampler.reset_task([dic_traffic_env_conf["TRAFFIC_FILE"]], 0, reset_type='test')
    sampler.sample_sotl(policy, dic_traffic_env_conf["TRAFFIC_FILE"])
Example #10
def main(experiment_path, dataset_path, config_path, restore_path, workers):
    logging.basicConfig(level=logging.INFO)
    config = Config.from_json(config_path)
    fix_seed(config.seed)

    train_data = pd.concat([
        load_data(os.path.join(dataset_path, 'train-clean-100'),
                  workers=workers),
        load_data(os.path.join(dataset_path, 'train-clean-360'),
                  workers=workers),
    ])
    eval_data = pd.concat([
        load_data(os.path.join(dataset_path, 'dev-clean'), workers=workers),
    ])

    if config.vocab == 'char':
        vocab = CharVocab(CHAR_VOCAB)
    elif config.vocab == 'word':
        vocab = WordVocab(train_data['syms'], 30000)
    elif config.vocab == 'subword':
        vocab = SubWordVocab(10000)
    else:
        raise AssertionError('invalid config.vocab: {}'.format(config.vocab))

    train_transform = T.Compose([
        ApplyTo(['sig'], T.Compose([
            LoadSignal(SAMPLE_RATE),
            ToTensor(),
        ])),
        ApplyTo(['syms'], T.Compose([
            VocabEncode(vocab),
            ToTensor(),
        ])),
        Extract(['sig', 'syms']),
    ])
    eval_transform = train_transform

    train_dataset = TrainEvalDataset(train_data, transform=train_transform)
    eval_dataset = TrainEvalDataset(eval_data, transform=eval_transform)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=BatchSampler(train_data,
                                   batch_size=config.batch_size,
                                   shuffle=True,
                                   drop_last=True),
        num_workers=workers,
        collate_fn=collate_fn)

    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_sampler=BatchSampler(eval_data, batch_size=config.batch_size),
        num_workers=workers,
        collate_fn=collate_fn)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Model(SAMPLE_RATE, len(vocab))
    model_to_save = model
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)
    if restore_path is not None:
        load_weights(model_to_save, restore_path)

    if config.opt.type == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     config.opt.lr,
                                     weight_decay=1e-4)
    elif config.opt.type == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    config.opt.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    else:
        raise AssertionError('invalid config.opt.type {}'.format(
            config.opt.type))

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        len(train_data_loader) * config.epochs)

    # ==================================================================================================================
    # main loop

    train_writer = SummaryWriter(os.path.join(experiment_path, 'train'))
    eval_writer = SummaryWriter(os.path.join(experiment_path, 'eval'))
    best_wer = float('inf')

    for epoch in range(config.epochs):
        if epoch % 10 == 0:
            logging.info(experiment_path)

        # ==============================================================================================================
        # training

        metrics = {
            'loss': Mean(),
            'fps': Mean(),
        }

        model.train()
        t1 = time.time()
        for (sigs, labels), (sigs_mask, labels_mask) in tqdm(
                train_data_loader,
                desc='epoch {} training'.format(epoch),
                smoothing=0.01):
            sigs, labels = sigs.to(device), labels.to(device)
            sigs_mask, labels_mask = sigs_mask.to(device), labels_mask.to(
                device)

            logits, etc = model(sigs, labels[:, :-1], sigs_mask,
                                labels_mask[:, :-1])

            loss = compute_loss(input=logits,
                                target=labels[:, 1:],
                                mask=labels_mask[:, 1:],
                                smoothing=config.label_smoothing)
            metrics['loss'].update(loss.data.cpu().numpy())

            lr = np.squeeze(scheduler.get_lr())

            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()
            scheduler.step()

            t2 = time.time()
            metrics['fps'].update(1 / ((t2 - t1) / sigs.size(0)))
            t1 = t2

        with torch.no_grad():
            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            print('[EPOCH {}][TRAIN] {}'.format(
                epoch, ', '.join('{}: {:.4f}'.format(k, metrics[k])
                                 for k in metrics)))
            for k in metrics:
                train_writer.add_scalar(k, metrics[k], global_step=epoch)
            train_writer.add_scalar('learning_rate', lr, global_step=epoch)

            train_writer.add_image('spectras',
                                   torchvision.utils.make_grid(
                                       etc['spectras'],
                                       nrow=compute_nrow(etc['spectras']),
                                       normalize=True),
                                   global_step=epoch)
            for k in etc['weights']:
                w = etc['weights'][k]
                train_writer.add_image('weights/{}'.format(k),
                                       torchvision.utils.make_grid(
                                           w,
                                           nrow=compute_nrow(w),
                                           normalize=True),
                                       global_step=epoch)

            for i, (true, pred) in enumerate(
                    zip(labels[:, 1:][:4].detach().data.cpu().numpy(),
                        np.argmax(logits[:4].detach().data.cpu().numpy(),
                                  -1))):
                print('{}:'.format(i))
                text = vocab.decode(
                    take_until_token(true.tolist(), vocab.eos_id))
                print(colored(text, 'green'))
                text = vocab.decode(
                    take_until_token(pred.tolist(), vocab.eos_id))
                print(colored(text, 'yellow'))

        # ==============================================================================================================
        # evaluation

        metrics = {
            # 'loss': Mean(),
            'wer': Mean(),
        }

        model.eval()
        with torch.no_grad(), Pool(workers) as pool:
            for (sigs, labels), (sigs_mask, labels_mask) in tqdm(
                    eval_data_loader,
                    desc='epoch {} evaluating'.format(epoch),
                    smoothing=0.1):
                sigs, labels = sigs.to(device), labels.to(device)
                sigs_mask, labels_mask = sigs_mask.to(device), labels_mask.to(
                    device)

                logits, etc = model.infer(sigs,
                                          sigs_mask,
                                          sos_id=vocab.sos_id,
                                          eos_id=vocab.eos_id,
                                          max_steps=labels.size(1) + 10)

                # loss = compute_loss(
                #     input=logits, target=labels[:, 1:], mask=labels_mask[:, 1:], smoothing=config.label_smoothing)
                # metrics['loss'].update(loss.data.cpu().numpy())

                wer = compute_wer(input=logits,
                                  target=labels[:, 1:],
                                  vocab=vocab,
                                  pool=pool)
                metrics['wer'].update(wer)

        with torch.no_grad():
            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            print('[EPOCH {}][EVAL] {}'.format(
                epoch, ', '.join('{}: {:.4f}'.format(k, metrics[k])
                                 for k in metrics)))
            for k in metrics:
                eval_writer.add_scalar(k, metrics[k], global_step=epoch)

            eval_writer.add_image('spectras',
                                  torchvision.utils.make_grid(
                                      etc['spectras'],
                                      nrow=compute_nrow(etc['spectras']),
                                      normalize=True),
                                  global_step=epoch)
            for k in etc['weights']:
                w = etc['weights'][k]
                eval_writer.add_image('weights/{}'.format(k),
                                      torchvision.utils.make_grid(
                                          w,
                                          nrow=compute_nrow(w),
                                          normalize=True),
                                      global_step=epoch)

        save_model(model_to_save, experiment_path)
        if metrics['wer'] < best_wer:
            best_wer = metrics['wer']
            save_model(model_to_save,
                       mkdir(os.path.join(experiment_path, 'best')))
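The training and evaluation loops above unpack each batch as (sigs, labels), (sigs_mask, labels_mask), so collate_fn must pad the variable-length signals and symbol sequences and return matching boolean masks. The real collate_fn is not shown; a plausible sketch, assuming each dataset item is a (sig, syms) pair of 1-D tensors:

import torch
from torch.nn.utils.rnn import pad_sequence


def collate_fn(batch):
    # Hypothetical collate: pads (sig, syms) pairs and builds padding masks.
    sigs, syms = zip(*batch)

    def pad_and_mask(seqs):
        lengths = torch.tensor([len(s) for s in seqs])
        padded = pad_sequence(list(seqs), batch_first=True)
        mask = torch.arange(padded.size(1))[None, :] < lengths[:, None]
        return padded, mask

    sigs, sigs_mask = pad_and_mask(sigs)
    syms, syms_mask = pad_and_mask(syms)
    return (sigs, syms), (sigs_mask, syms_mask)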
Example #11
def main(args):
    print('starting....')

    utils.set_seed(args.seed, cudnn=args.make_deterministic)

    continuous_actions = (args.env_name in ['AntVel-v1', 'AntDir-v1',
                                            'AntPos-v0', 'HalfCheetahVel-v1', 'HalfCheetahDir-v1',
                                            '2DNavigation-v0'])

    # subfolders for logging
    method_used = 'maml' if args.maml else 'cavia'
    num_context_params = str(args.num_context_params) + '_' if not args.maml else ''
    output_name = num_context_params + 'lr=' + str(args.fast_lr) + 'tau=' + str(args.tau)
    output_name += '_' + datetime.datetime.now().strftime('%d_%m_%Y_%H_%M_%S')
    dir_path = os.path.dirname(os.path.realpath(__file__))
    log_folder = os.path.join(os.path.join(dir_path, 'logs'), args.env_name, method_used, output_name)
    save_folder = os.path.join(os.path.join(dir_path, 'saves'), output_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(log_folder):
        os.makedirs(log_folder)

    # initialise tensorboard writer
    writer = SummaryWriter(log_folder)

    # save config file
    with open(os.path.join(save_folder, 'config.json'), 'w') as f:
        config = {k: v for (k, v) in vars(args).items() if k != 'device'}
        config.update(device=args.device.type)
        json.dump(config, f, indent=2)
    with open(os.path.join(log_folder, 'config.json'), 'w') as f:
        config = {k: v for (k, v) in vars(args).items() if k != 'device'}
        config.update(device=args.device.type)
        json.dump(config, f, indent=2)

    sampler = BatchSampler(args.env_name, batch_size=args.fast_batch_size, num_workers=args.num_workers,
                           device=args.device, seed=args.seed)

    if continuous_actions:
        if not args.maml:
            policy = CaviaMLPPolicy(
                int(np.prod(sampler.envs.observation_space.shape)),
                int(np.prod(sampler.envs.action_space.shape)),
                hidden_sizes=(args.hidden_size,) * args.num_layers,
                num_context_params=args.num_context_params,
                device=args.device
            )
        else:
            policy = NormalMLPPolicy(
                int(np.prod(sampler.envs.observation_space.shape)),
                int(np.prod(sampler.envs.action_space.shape)),
                hidden_sizes=(args.hidden_size,) * args.num_layers
            )
    else:
        if not args.maml:
            raise NotImplementedError
        else:
            policy = CategoricalMLPPolicy(
                int(np.prod(sampler.envs.observation_space.shape)),
                sampler.envs.action_space.n,
                hidden_sizes=(args.hidden_size,) * args.num_layers)

    # initialise baseline
    baseline = LinearFeatureBaseline(int(np.prod(sampler.envs.observation_space.shape)))

    # initialise meta-learner
    metalearner = MetaLearner(sampler, policy, baseline, gamma=args.gamma, fast_lr=args.fast_lr, tau=args.tau,
                              device=args.device)

    for batch in range(args.num_batches):

        # get a batch of tasks
        tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size)

        # do the inner-loop update for each task
        # this returns training (before update) and validation (after update) episodes
        episodes, inner_losses = metalearner.sample(tasks, first_order=args.first_order)

        # take the meta-gradient step
        outer_loss = metalearner.step(episodes, max_kl=args.max_kl, cg_iters=args.cg_iters,
                                      cg_damping=args.cg_damping, ls_max_steps=args.ls_max_steps,
                                      ls_backtrack_ratio=args.ls_backtrack_ratio)

        # -- logging

        curr_returns = total_rewards(episodes, interval=True)
        print('   return after update: ', curr_returns[0][1])

        # Tensorboard
        writer.add_scalar('policy/actions_train', episodes[0][0].actions.mean(), batch)
        writer.add_scalar('policy/actions_test', episodes[0][1].actions.mean(), batch)

        writer.add_scalar('running_returns/before_update', curr_returns[0][0], batch)
        writer.add_scalar('running_returns/after_update', curr_returns[0][1], batch)

        writer.add_scalar('running_cfis/before_update', curr_returns[1][0], batch)
        writer.add_scalar('running_cfis/after_update', curr_returns[1][1], batch)

        writer.add_scalar('loss/inner_rl', np.mean(inner_losses), batch)
        writer.add_scalar('loss/outer_rl', outer_loss.item(), batch)

        # -- evaluation

        # evaluate for multiple update steps
        if batch % args.test_freq == 0:
            test_tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size)
            test_episodes = metalearner.test(test_tasks, num_steps=args.num_test_steps,
                                             batch_size=args.test_batch_size, halve_lr=args.halve_test_lr)
            all_returns = total_rewards(test_episodes, interval=True)
            for num in range(args.num_test_steps + 1):
                writer.add_scalar('evaluation_rew/avg_rew ' + str(num), all_returns[0][num], batch)
                writer.add_scalar('evaluation_cfi/avg_rew ' + str(num), all_returns[1][num], batch)

            print('   inner RL loss:', np.mean(inner_losses))
            print('   outer RL loss:', outer_loss.item())

        # -- save policy network
        with open(os.path.join(save_folder, 'policy-{0}.pt'.format(batch)), 'wb') as f:
            torch.save(policy.state_dict(), f)
Example #12
def metalight_train(dic_exp_conf, dic_agent_conf, _dic_traffic_env_conf,
                    _dic_path, tasks, batch_id):
    '''
        metalight meta-train function 

        Arguments:
            dic_exp_conf:           dict,   configuration of this experiment
            dic_agent_conf:         dict,   configuration of agent
            _dic_traffic_env_conf:  dict,   configuration of traffic environment
            _dic_path:              dict,   path of source files and output files
            tasks:                  list,   traffic files name in this round 
            batch_id:               int,    round number
    '''
    tot_path = []
    tot_traffic_env = []
    for task in tasks:
        dic_traffic_env_conf = copy.deepcopy(_dic_traffic_env_conf)
        dic_path = copy.deepcopy(_dic_path)
        dic_path.update({
            "PATH_TO_DATA":
            os.path.join(dic_path['PATH_TO_DATA'],
                         task.split(".")[0])
        })
        # parse roadnet
        dic_traffic_env_conf["ROADNET_FILE"] = dic_traffic_env_conf[
            "traffic_category"]["traffic_info"][task][2]
        dic_traffic_env_conf["FLOW_FILE"] = dic_traffic_env_conf[
            "traffic_category"]["traffic_info"][task][3]
        roadnet_path = os.path.join(
            dic_path['PATH_TO_DATA'], dic_traffic_env_conf["traffic_category"]
            ["traffic_info"][task][2])  # dic_traffic_env_conf['ROADNET_FILE'])
        lane_phase_info = parse_roadnet(roadnet_path)
        dic_traffic_env_conf["LANE_PHASE_INFO"] = lane_phase_info[
            "intersection_1_1"]
        dic_traffic_env_conf["num_lanes"] = int(
            len(lane_phase_info["intersection_1_1"]["start_lane"]) /
            4)  # num_lanes per direction
        dic_traffic_env_conf["num_phases"] = len(
            lane_phase_info["intersection_1_1"]["phase"])

        dic_traffic_env_conf["TRAFFIC_FILE"] = task

        tot_path.append(dic_path)
        tot_traffic_env.append(dic_traffic_env_conf)

    sampler = BatchSampler(dic_exp_conf=dic_exp_conf,
                           dic_agent_conf=dic_agent_conf,
                           dic_traffic_env_conf=tot_traffic_env,
                           dic_path=tot_path,
                           batch_size=args.fast_batch_size,
                           num_workers=args.num_workers)

    policy = config.DIC_AGENTS[args.algorithm](
        dic_agent_conf=dic_agent_conf,
        dic_traffic_env_conf=tot_traffic_env,
        dic_path=tot_path)

    metalearner = MetaLearner(sampler,
                              policy,
                              dic_agent_conf=dic_agent_conf,
                              dic_traffic_env_conf=tot_traffic_env,
                              dic_path=tot_path)

    if batch_id == 0:
        params = pickle.load(
            open(os.path.join(dic_path['PATH_TO_MODEL'], 'params_init.pkl'),
                 'rb'))
        params = [params] * len(policy.policy_inter)
        metalearner.meta_params = params
        metalearner.meta_target_params = params

    else:
        params = pickle.load(
            open(
                os.path.join(dic_path['PATH_TO_MODEL'],
                             'params_%d.pkl' % (batch_id - 1)), 'rb'))
        params = [params] * len(policy.policy_inter)
        metalearner.meta_params = params
        period = dic_agent_conf['PERIOD']
        target_id = int((batch_id - 1) / period)
        meta_params = pickle.load(
            open(
                os.path.join(dic_path['PATH_TO_MODEL'],
                             'params_%d.pkl' % (target_id * period)), 'rb'))
        meta_params = [meta_params] * len(policy.policy_inter)
        metalearner.meta_target_params = meta_params

    metalearner.sample_metalight(tasks, batch_id)
Example #13
from policy import PolicyGradientModel, clone_policy
from sampler import BatchSampler
import multiprocessing as mp

sampler = BatchSampler('Maze-v0',
                       batch_size=20,
                       num_workers=mp.cpu_count() - 1)

print(sampler.envs.observation_space.shape)
print(sampler.envs.action_space.shape)

tasks = sampler.sample_tasks(num_tasks=40)