Пример #1
0
def run():
    """Merge gold data from several configs and train one model on the result.

    Config1_1's gold data is loaded, transformed with Config1_1 and then with
    Config1_2, and merged into a new, empty GoldDataContainer that takes its
    category list. The gold data for Config2 through Config6 is then loaded,
    transformed with its own config and merged in, in that order. A trainer is
    built from ConfigTrain with the merged category list and trained on the
    merged container. Finally ``embed()`` is called.
    """

    # Config1_1 data gets two transform passes (Config1_1, then Config1_2)
    # and seeds the merged container with its category list.
    gdc_1 = main.load_gold_data(Config1_1)
    gdc_1 = main.transform_gold_data(Config1_1, gdc_1)
    gdc_1 = main.transform_gold_data(Config1_2, gdc_1)
    gdc = GoldDataContainer(cats_list=gdc_1.cats_list)
    # NOTE(review): per its name, the merge helper presumably requires both
    # containers to share the same categories — confirm.
    gdc = gold_data_manager.merge_assuming_identical_categories(gdc, gdc_1)

    # Config2..Config6: each is loaded, transformed with its own config, merged.
    gdc_2 = main.load_gold_data(Config2)
    gdc_2 = main.transform_gold_data(Config2, gdc_2)
    gdc = gold_data_manager.merge_assuming_identical_categories(gdc, gdc_2)

    gdc_3 = main.load_gold_data(Config3)
    gdc_3 = main.transform_gold_data(Config3, gdc_3)
    gdc = gold_data_manager.merge_assuming_identical_categories(gdc, gdc_3)

    gdc_4 = main.load_gold_data(Config4)
    gdc_4 = main.transform_gold_data(Config4, gdc_4)
    gdc = gold_data_manager.merge_assuming_identical_categories(gdc, gdc_4)

    gdc_5 = main.load_gold_data(Config5)
    gdc_5 = main.transform_gold_data(Config5, gdc_5)
    gdc = gold_data_manager.merge_assuming_identical_categories(gdc, gdc_5)

    gdc_6 = main.load_gold_data(Config6)
    gdc_6 = main.transform_gold_data(Config6, gdc_6)
    gdc = gold_data_manager.merge_assuming_identical_categories(gdc, gdc_6)

    # Train a single model on the merged data.
    trainer = main.init_trainer(ConfigTrain, cats_list=gdc.cats_list)
    main.run_training(config=ConfigTrain,
                      trainer=trainer,
                      gold_data_container=gdc)

    # NOTE(review): embed() is presumably IPython.embed, which opens an
    # interactive shell over this function's locals (gdc, gdc_1 ... gdc_6,
    # trainer) — confirm. That is why the per-config locals are kept as
    # separate names.
    embed()
Пример #2
0
def run():
    """Train one model per ConfigTdc* config on the ConfigBase gold data.

    The gold data is loaded and transformed once with ConfigBase. Then, for
    ConfigTdc100 and ConfigTdc80 in that order, a fresh trainer is built
    with the data's category list and trained on that same data.
    """
    gold = main.transform_gold_data(ConfigBase, main.load_gold_data(ConfigBase))

    for train_cfg in (ConfigTdc100, ConfigTdc80):
        tr = main.init_trainer(train_cfg, cats_list=gold.cats_list)
        main.run_training(config=train_cfg, trainer=tr, gold_data_container=gold)
def run():
    """Run 30 consecutive training rounds with ConfigSub on the same data.

    Round 0 sets ``should_create_model`` (and clears ``should_load_model``).
    Every later round does the opposite, presumably so it continues from the
    model saved by the previous round (confirm). ConfigSub is mutated in
    place and keeps the last round's flags after this returns.
    """
    gold = main.load_gold_data(ConfigSub)
    gold = main.transform_gold_data(ConfigSub, gold)

    for round_idx in range(30):
        fresh = round_idx == 0
        ConfigSub.should_load_model = not fresh
        ConfigSub.should_create_model = fresh

        tr = main.init_trainer(config=ConfigSub, cats_list=gold.cats_list)
        main.run_training(ConfigSub, tr, gold)
Пример #4
0
def train(trainer1, trainer2):
    """Do one training run for each of two comparison trainers.

    The gold data is loaded and transformed with ConfigTrainCompareBase. If
    ``trainer1`` is None, ConfigTrainCompareBase is switched to
    "create, don't load" a model. Both trainers are then built from
    ConfigTrainCompare1 / ConfigTrainCompare2. Each trainer is then trained
    once with its config. Both are returned so the caller can pass them
    back in on the next call.

    NOTE(review): when ``trainer1`` is None, any ``trainer2`` passed in is
    discarded and rebuilt.
    """
    gold = main.load_gold_data(ConfigTrainCompareBase)
    gold = main.transform_gold_data(ConfigTrainCompareBase, gold)

    if trainer1 is None:
        # NOTE(review): flags are set on the base config, presumably inherited
        # by ConfigTrainCompare1/2 — confirm.
        ConfigTrainCompareBase.should_load_model = False
        ConfigTrainCompareBase.should_create_model = True
        trainer1 = main.init_trainer(ConfigTrainCompare1, cats_list=gold.cats_list)
        trainer2 = main.init_trainer(ConfigTrainCompare2, cats_list=gold.cats_list)

    runs = ((ConfigTrainCompare1, trainer1), (ConfigTrainCompare2, trainer2))
    for cfg, tr in runs:
        main.run_training(cfg, tr, gold)

    return trainer1, trainer2
Пример #5
0
def test_nec():
    """Train a small NEC agent on ENV and check the Q-values it learns.

    NumPy and torch are seeded before the embedding network is built and the
    experiment runs. The Q-values for states 0, 1, 2 under action 0 and then
    action 1 must each be within 0.2 of the expected values.
    """
    env = ENV()
    key_size = 4
    seed = 1

    np.random.seed(seed)
    torch.manual_seed(seed)

    embedding_net = nn.Linear(env.observation_space.shape[0], key_size)

    # Experiment / runner settings.
    run_settings = {
        "env": env,
        "env_name": "test_env",
        "exp_name": "test",
        "device": torch.device("cpu"),
        "max_steps": 40,
        "initial_epsilon": 1,
        "final_epsilon": 0.5,
        "epsilon_anneal_start": 1,
        "epsilon_anneal_end": 2,
        "start_learning_step": 1,
        "replay_frequency": 1,
        "eval_frequency": 1000000,  # effectively no evaluation in this test
    }
    # NEC agent settings.
    agent_settings = {
        "train_eps": 1,  # start the agent fully exploratory
        "eval_eps": 0,
        "num_actions": env.action_space.n,
        "observation_shape": env.observation_space.shape[0],
        "replay_buffer_size": 20,
        "batch_size": 3,
        "discount": 1,
        "horizon": 1,
        "learning_rate": 0.1,
    }
    # NEC-specific settings.
    nec_settings = {"embedding_net": embedding_net}
    # DND settings.
    dnd_settings = {
        "dnd_capacity": 15,
        "num_neighbours": 1,
        "key_size": key_size,
        "alpha": 0.99,
    }
    # Merge the sections in order; no key appears in more than one section.
    config = {**run_settings, **agent_settings, **nec_settings, **dnd_settings}

    agent = run_training(config, True)

    # (state, action) pairs checked: states 0-2 with action 0, then action 1.
    state_batch = env._obs[[0, 1, 2, 0, 1, 2]]
    action_batch = [0, 0, 0, 1, 1, 1]
    q_values = agent.get_q_values(state_batch, action_batch)
    expected_values = np.array([5.0, 2.0, 3.0, 3.0, 5.0, 1.0])

    print(f'Expected: {expected_values}')
    print(f'Got: {q_values}')
    assert np.allclose(q_values, expected_values, atol=0.2)
"""
Created on Sat Jul 11 00:09:38 2020

@author: btayart
"""


#%% Redo all training with ignore_unlabeled set to True
# Train the ENet model on vanilla CamVid
from main import run_training
from custom import CustomArgs
import torch
# Baseline: ENet on CamVid with all classes, unlabeled pixels ignored.
args = CustomArgs(
    resume=False,
    batch_size=6,
    print_step=False,
    ignore_unlabeled=True,
    model_type="ENet",
    name="ENet_allclasses",
)
run_training(args)
torch.cuda.empty_cache()  # release cached GPU memory before the next run
# BEST VALIDATION
# Epoch: 170
# Mean IoU: 0.651799197633767

# Same setup, trained with weight_decay=2e-3.
args = CustomArgs(
    resume=False,
    batch_size=6,
    print_step=False,
    weight_decay=2e-3,
    ignore_unlabeled=True,
    model_type="ENet",
    name="ENet_allclasses_wd2",
)
run_training(args)
torch.cuda.empty_cache()  # release cached GPU memory before the next run
# BEST VALIDATION
# Epoch: 270
# Mean IoU: 0.6653864847519912

# Train the ENet model on CamVid with 'people' classes dropped
Пример #7
0
def run():
    """Load and transform the ConfigSub gold data, then train once on it."""
    cfg = ConfigSub
    gold = main.transform_gold_data(cfg, main.load_gold_data(cfg))
    tr = main.init_trainer(config=cfg, cats_list=gold.cats_list)
    main.run_training(cfg, tr, gold)
Пример #8
0
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2018/6/7 0007 14:51
# @Author  : jsz
# @Software: PyCharm

import main

import os

main.rename()

# Dataset names passed to main.convert_to_tfrecord.
name_test = 'testrandomtrain'
name_test1 = 'testrandomval'

# Paths are built with os.path.join rather than hard-coded '\\' separators,
# so the script (which carries a POSIX shebang) also works outside Windows.
# On Windows these evaluate to the same strings as before, e.g.
# '.\\data\\train\\'. The trailing '' keeps the final separator on the dirs.
test_dir = os.path.join('.', 'data', 'train', '')
save_dir = os.path.join('.', '')
test_dir1 = os.path.join('.', 'data', 'val', '')
save_dir1 = os.path.join('.', '')

# <save_dir>/<name>.tfrecords — presumably the file convert_to_tfrecord
# writes for each dataset (confirm). Deriving it keeps the two in sync.
tfrecords_file = os.path.join(save_dir, name_test + '.tfrecords')
tfrecords_file1 = os.path.join(save_dir1, name_test1 + '.tfrecords')

# Convert the training and validation image folders, then train on them.
images, labels = main.get_file(test_dir)
main.convert_to_tfrecord(images, labels, save_dir, name_test)
images1, labels1 = main.get_file(test_dir1)
main.convert_to_tfrecord(images1, labels1, save_dir1, name_test1)

main.run_training(tfrecords_file, tfrecords_file1)