def setUp(self):
        """Create all tables and seed a single ``SimpleModel`` row for "simple_app"."""
        db.create_all()

        seeded = SimpleModel()
        seeded.app_name = "simple_app"

        session = db.session
        session.add(seeded)
        session.commit()
示例#2
0
 def test_missing(self):
     """An empty store returns no rows; after one save exactly that row is returned."""
     # assertEqual replaces failUnlessEqual, a deprecated alias removed in Python 3.12.
     res = list(SimpleModel.get_all())
     self.assertEqual(len(res), 0)
     SimpleModel(int1=3).save()
     res = list(SimpleModel.get_all())
     self.assertEqual(len(res), 1)
     self.assertEqual(res[0].int1, 3)
示例#3
0
 def test_set_missing_field(self):
     """A field absent from the stored document can be set and saved later."""
     # The document is stored with only the short key 'i1', so int2 starts unset.
     SimpleModel({'i1': 2, '_id': 'timon'}).save()
     ob = SimpleModel.get_id('timon')
     ob.int2 = 15
     ob.save()
     ob = SimpleModel.get_id('timon')
     # assertEqual replaces failUnlessEqual (alias removed in Python 3.12).
     self.assertEqual(ob.int2, 15)
示例#4
0
 def test_missing(self):
     """An empty store returns no rows; after one save exactly that row is returned."""
     # assertEqual replaces failUnlessEqual, a deprecated alias removed in Python 3.12.
     res = list(SimpleModel.get_all())
     self.assertEqual(len(res), 0)
     SimpleModel(int1=3).save()
     res = list(SimpleModel.get_all())
     self.assertEqual(len(res), 1)
     self.assertEqual(res[0].int1, 3)
示例#5
0
def persist(simple_data):
    """Save ``simple_data`` as a new ``SimpleModel`` and render the newest row.

    :param simple_data: payload passed unchanged to the ``SimpleModel`` constructor.
    :return: the rendered ``save_result.j2`` template, with ``last`` bound to the
        row that has the highest ``id``.
    """
    model = SimpleModel(simple_data)
    model.save()

    # Order by the mapped column on the class. On an instance, ``model.id`` is
    # the row's plain id value, which has no ``.desc()``. This assumes
    # SimpleModel is a (Flask-)SQLAlchemy model, as ``.query.order_by`` implies.
    last = SimpleModel.query.order_by(SimpleModel.id.desc()).first()

    return render_template('save_result.j2', last=last)
示例#6
0
 def test_ignored(self):
     """Check how the constructor treats long names, short keys and extra keys.

     ``int1`` (long name) and ``i2`` (short key) both populate declared fields,
     and the short key is not readable as an attribute. The extra ``secret``
     and ``keep`` values are readable, but only ``keep`` appears in ``to_d()``.
     """
     o = SimpleModel(int1=17, i2=13, secret=42, keep=100)
     # assert* methods replace the failUnless* aliases removed in Python 3.12.
     self.assertEqual(17, o.int1)
     self.assertEqual(13, o.int2)
     self.assertRaises(AttributeError, getattr, o, 'i2')
     self.assertEqual(100, o.keep)
     self.assertEqual(42, o.secret)
     self.assertEqual(o.to_d(), {'i1': 17, 'i2': 13, 'keep': 100})
    def test_serialization(self):
        """Saving symbol counts to a file and loading them back preserves ``symbols``.

        ``tempfile.mkstemp`` returns an open OS-level descriptor plus a path that
        is never deleted automatically. Both are released here, so the test
        leaks neither a descriptor nor a stray file on each run.
        """
        import os

        self.mod.data = TEST_DATA
        self.mod._compute_symbol_counts()
        handle, filename = tempfile.mkstemp()
        # save_to_file/load_from_file receive only the path (and presumably
        # reopen it), so the raw descriptor is not needed.
        os.close(handle)
        try:
            self.mod.save_to_file(filename)

            test_mod = SimpleModel()
            test_mod.load_from_file(filename)
            self.assertEqual(test_mod.symbols, self.mod.symbols)
        finally:
            os.remove(filename)
示例#8
0
 def test_update_object(self):
     """Saving a re-fetched object replaces the stored version."""
     #make sure that we replace objects when they are updated
     self.o1._id = "mustafa"
     self.o1.int1 = 1
     self.o1.int2 = 2
     self.o1.save()
     ob = SimpleModel.get_id("mustafa")
     ob.int2 = 3
     ob.save()
     ob = SimpleModel.get_id("mustafa")
     # assertEqual replaces failUnlessEqual (alias removed in Python 3.12).
     self.assertEqual(3, ob.int2)
示例#9
0
 def test_remove_field(self):
     """Setting a saved field to None and saving again clears it in storage."""
     self.o2._id = "nala"
     self.o2.int1 = 2
     self.o2.int2 = 3
     self.o2.save()
     item = SimpleModel.get_id("nala")
     # assert* methods replace the failUnless* aliases removed in Python 3.12.
     self.assertEqual(item.int2, 3)
     item.int2 = None
     item.save()
     result = SimpleModel.get_id("nala")
     self.assertIsNone(result.int2)
示例#10
0
def adversarial_example(torch_dev_dataset, parameters, ckpt):
    """Run part 3: attack a trained checkpoint with adversarial examples.

    :param torch_dev_dataset: dataset fed to the attack one sample per batch.
    :param parameters: config dict; ``'adversarial_epsilons'`` and
        ``'path_plots_adversarial'`` are passed to ``Adversarial``.
    :param ckpt: checkpoint path loaded into a fresh ``SimpleModel``.
    """
    print("Start part 3")
    # Question 3 - Adversarial example
    single_sample_loader = DataLoader(dataset=torch_dev_dataset, batch_size=1)
    victim = SimpleModel()
    victim.load(path=ckpt)

    epsilons = parameters['adversarial_epsilons']
    plots_dir = parameters['path_plots_adversarial']
    attack = Adversarial(victim, single_sample_loader, epsilons, plots_dir)

    # Run the attack first, then produce both kinds of plots.
    attack.__attack__()
    attack.__plot_attack__()
    attack.__plot_examples__()
    print("End part 3")
示例#11
0
def dataset_fix(dataset, trainer, parameters, counters_dev, is_evaluating,
                name):
    """Write a fixed copy of ``dataset`` using the checkpoint of ``trainer``.

    :param dataset: dataset to clean; wrapped in a default ``DataLoader`` for evaluation.
    :param trainer: trainer whose ``ckpt`` path is loaded into a new ``SimpleModel``.
    :param parameters: config dict. ``'fixed_dataset'`` or ``'fixed_dataset_dev'``
        names the output path.
    :param counters_dev: forwarded to ``evaluate``.
    :param is_evaluating: forwarded to ``evaluate``.
    :param name: ``"train"`` selects ``parameters['fixed_dataset']``; any other
        value selects ``parameters['fixed_dataset_dev']``.
    :return: the fixed dataset, re-read with ``get_dataset_as_torch_dataset``.
    """
    loader = DataLoader(dataset=dataset)
    checkpoint_model = SimpleModel()
    checkpoint_model.load(path=trainer.ckpt)
    evaluator = evaluate(checkpoint_model, loader, parameters, counters_dev,
                         "improved_model_eval", is_evaluating)

    output_key = 'fixed_dataset' if name == "train" else 'fixed_dataset_dev'
    output_path = parameters[output_key]

    fix_dataset(evaluator, dataset, output_path)
    return get_dataset_as_torch_dataset(path=output_path)
class TestSimpleModel(unittest.TestCase):
    """File round-trip tests for ``SimpleModel`` symbol counts."""

    def setUp(self):
        self.mod = SimpleModel()

    def tearDown(self):
        self.mod = None

    def test_serialization(self):
        """Saving to a file and loading it back preserves ``symbols``."""
        import os

        self.mod.data = TEST_DATA
        self.mod._compute_symbol_counts()
        # mkstemp returns an open descriptor and a file that is never removed
        # automatically. Close the descriptor now and delete the file at the
        # end, so the test leaks neither.
        handle, filename = tempfile.mkstemp()
        os.close(handle)
        try:
            self.mod.save_to_file(filename)

            test_mod = SimpleModel()
            test_mod.load_from_file(filename)
            self.assertEqual(test_mod.symbols, self.mod.symbols)
        finally:
            os.remove(filename)
class TestSimpleModel(unittest.TestCase):
    """Smoke test that the graph built by ``SimpleModel.forward`` can be evaluated."""

    def setUp(self):
        self.model = SimpleModel()
        # Random input of shape (5, 256 * 256 * 3).
        inp = tf.random.normal(shape=(5, 256 * 256 * 3))
        self.out = self.model.forward(inp)

    def test_forward(self):
        with tf.Session() as sess:
            # tf.initialize_all_variables is deprecated; the TF1 graph-mode
            # replacement is tf.global_variables_initializer.
            sess.run(tf.global_variables_initializer())
            print(sess.run(self.out))
示例#14
0
def build_simple(should_setup, check_nan, unroll_batch_num, encode_key, no_per_note):
    """Build a ``SimpleModel`` for the requested note encoding.

    :param should_setup: forwarded to ``SimpleModel`` as ``setup``.
    :param check_nan: forwarded to ``SimpleModel`` as ``nanguard``.
    :param unroll_batch_num: forwarded to ``SimpleModel`` unchanged.
    :param encode_key: one of ``"abs"`` (``AbsoluteSequentialEncoding``),
        ``"cot"`` (``CircleOfThirdsEncoding``) or ``"rel"`` (``RelativeJumpEncoding``).
    :param no_per_note: only matters for ``"rel"``. Layer sizes are
        ``(200, 10)`` for ``"rel"`` unless this is true; every other case
        uses ``(300, 0)``.
    :return: the constructed ``SimpleModel``.
    :raises ValueError: if ``encode_key`` is not a supported key. Previously
        an unknown key fell through every branch and crashed with
        ``UnboundLocalError`` on ``enc``.
    """
    if encode_key == "abs":
        enc = AbsoluteSequentialEncoding(constants.BOUNDS.lowbound, constants.BOUNDS.highbound)
        inputs = [input_parts.BeatInputPart(), input_parts.ChordShiftInputPart()]
    elif encode_key == "cot":
        # Second argument: the default bound span divided by 12 (presumably octaves).
        enc = CircleOfThirdsEncoding(
            constants.BOUNDS.lowbound,
            (constants.BOUNDS.highbound - constants.BOUNDS.lowbound) // 12)
        inputs = [input_parts.BeatInputPart(), input_parts.ChordShiftInputPart()]
    elif encode_key == "rel":
        enc = RelativeJumpEncoding()
        inputs = None
    else:
        raise ValueError(
            "unknown encode_key {!r}; expected 'abs', 'cot' or 'rel'".format(encode_key))

    if encode_key == "rel" and not no_per_note:
        sizes = [(200, 10), (200, 10)]
    else:
        sizes = [(300, 0), (300, 0)]
    bounds = constants.NoteBounds(48, 84) if encode_key == "cot" else constants.BOUNDS
    return SimpleModel(enc, sizes, bounds=bounds, inputs=inputs, dropout=0.5,
                       setup=should_setup, nanguard=check_nan,
                       unroll_batch_num=unroll_batch_num)
示例#15
0
def create_model(network_name):
    """
    Kind-of model factory.
    Edit it to add more models.
    :param network_name: The string input from the terminal
    :return: The model
    :raises ValueError: if ``network_name`` is not a known network. ValueError
        is a subclass of Exception, so existing ``except Exception`` callers
        keep working.
    """
    if network_name == 'simple':
        return SimpleModel()
    if network_name == 'dqn':
        return DqnModel()
    if network_name == 'monte_carlo':
        return MonteCarloModel()
    raise ValueError('net {} is not known'.format(network_name))
示例#16
0
def run_task(task):
    """Build the model described by ``task``, then restore it or train it.

    :param task: dict with ``'name'``, ``'data_provider'`` and
        ``'module_handlers'``. With a ``'graph'`` key, a ``GraphModel`` is
        built from ``'input_handlers'``, ``'output_handlers'`` and ``'graph'``.
        Otherwise a chain-shaped ``SimpleModel`` is built from
        ``'input_handler'`` and ``'output_handler'``.
    :return: the model. It is restored via ``restore_model()`` when
        ``$modNN_DIR/results/<experiment_name>/model/trained_model.ckpt.index``
        exists, and trained for ``_num_epochs`` epochs otherwise.
    """
    results_dir = os.path.join(os.environ['modNN_DIR'], 'results')

    # set output path and ensure a directory exists for this path;
    # exist_ok=True avoids the check-then-create race of isdir() + makedirs().
    output_path = os.path.join(results_dir, task['name'])
    os.makedirs(output_path, exist_ok=True)

    # initialise the model
    if 'graph' in task:
        # use the given graph structure
        model = GraphModel(task['name'],
                           task['data_provider'],
                           task['input_handlers'],
                           task['module_handlers'],
                           task['output_handlers'],
                           task['graph'],
                           add_summaries=True)
    else:
        # no graph structure given, create simple chain graph
        model = SimpleModel(task['name'],
                            task['data_provider'],
                            task['input_handler'],
                            task['module_handlers'],
                            task['output_handler'],
                            add_summaries=True)

    # report the model built
    print(model)

    # reload or train the model
    checkpoint_index = os.path.join(results_dir, model.experiment_name,
                                    'model', 'trained_model.ckpt.index')
    if os.path.isfile(checkpoint_index):
        model.restore_model()
    else:
        model.train(num_epochs=_num_epochs)

    return model
示例#17
0
def playing_with_learning_rate(train_loader, parameters):
    """Run part 2: train one copy of the pretrained model per learning rate.

    For each entry of ``parameters['lrs']``, a ``SimpleModel`` is loaded from
    ``parameters['pretrained_path']`` and trained with that rate. Its loss and
    accuracy curves are then plotted into ``parameters['path_lrs']``.
    """
    print("Start part 2")
    # Question 2 - Playing with learning rate
    print("Playing with learning rate")
    learning_rates = parameters['lrs']
    # Every model is constructed before any of them is loaded or trained.
    models = [SimpleModel() for _ in range(len(learning_rates))]
    for idx, (model, lr) in enumerate(zip(models, learning_rates)):
        model.load(parameters['pretrained_path'])
        model_name = "model_{}".format(idx)
        print("model: {}. lr: {}".format(model_name, lr))
        trainer = Trainer(model, train_loader, parameters['criterion'], lr,
                          parameters['betas'], parameters['epochs'],
                          parameters['batch_size'], parameters['num_classes'],
                          parameters['epsilon'], model_name,
                          parameters['path_lrs'])
        trainer.__train__(False)
        plot(trainer.losses, "{} Loss".format(trainer.name), "loss", "epoch",
             parameters['path_lrs'])
        plot(trainer.accuracies, "{} Accuracy".format(trainer.name),
             "accuracy", "epoch", parameters['path_lrs'])
    print("End Playing with learning rate")
    print("End part 2")
示例#18
0
def train_and_eval(dev_loader, train_loader, counters, parameters, name,
                   is_evaluating, path):
    """Evaluate the checkpoint at ``path`` on ``dev_loader``, then train it.

    :return: whatever ``training_loop`` returns for the loaded model.
    """
    # NOTE(review): dataset_fix calls evaluate(model, loader, parameters,
    # counters, ...), while this call passes counters before parameters. One
    # of the two argument orders is presumably wrong; confirm against
    # evaluate()'s signature.
    net = SimpleModel()
    net.load(path)
    evaluate(net, dev_loader, counters, parameters, name, is_evaluating)
    return training_loop(net, train_loader, parameters, name)
示例#19
0
def full_process(torch_train_dataset, torch_dev_dataset, train_loader,
                 dev_loader, counters_train, counters_dev, parameters):
    """Run part 1 end to end, then parts 2 and 3.

    Steps, in order:
    1. Evaluate and train the pretrained model.
    2. Reload the resulting checkpoint and evaluate it.
    3. Build fixed train/dev datasets with ``dataset_fix``.
    4. Retrain the pretrained weights on the fixed train set and evaluate on
       the fixed dev loader.
    5. Call ``playing_with_learning_rate`` and ``adversarial_example``.

    :return: checkpoint path of the model retrained on the fixed data.
    """
    # Load model
    pre_trained = SimpleModel()
    pre_trained.load(path=parameters['pretrained_path'])

    # Evaluate given model and train it
    evaluation(pre_trained, dev_loader, parameters, counters_dev,
               "pretrain_model", True)
    trainer = training_loop(pre_trained, train_loader, parameters,
                            "pretrain_model")

    # Load the best state of the model we've trained and evaluate
    trained = SimpleModel()
    trained.load(path=trainer.ckpt)
    evaluation(trained, dev_loader, parameters, counters_dev, "trained_model",
               True)

    # Model improvements - get mislabeled images and fix them
    print("Improving given model:")
    print("Fix train dataset")
    torch_train_dataset_fixed = dataset_fix(torch_train_dataset, trainer,
                                            parameters, counters_dev, False,
                                            "train")
    train_loader = create_data_loader(torch_train_dataset_fixed,
                                      counters_train, parameters, True)
    print(
        "=============================================================================="
    )
    print("Fix dev dataset")
    torch_dev_dataset_fixed = dataset_fix(torch_dev_dataset, trainer,
                                          parameters, counters_dev, False,
                                          "dev")
    # NOTE(review): the dev loader is built with counters_train (not
    # counters_dev) and the same trailing True as the train loader. Confirm
    # this is intended.
    dev_loader = create_data_loader(torch_dev_dataset_fixed, counters_train,
                                    parameters, True)
    print(
        "=============================================================================="
    )

    # NOTE(review): this inspects the *unfixed* torch_dev_dataset, although
    # torch_dev_dataset_fixed exists. It also runs after both loaders were
    # built, and the refreshed counters_train is not used again in this
    # function; only counters_dev feeds the evaluation below.
    counters_train, counters_dev = inspect_dataset(torch_train_dataset_fixed,
                                                   torch_dev_dataset)

    # Test given model after fixing datasets
    print("Test given model after fixing datasets")
    test_model = SimpleModel()
    test_model.load(path=parameters['pretrained_path'])
    trainer = training_loop(test_model, train_loader, parameters,
                            "train_improved_model_train")
    ckpt = trainer.ckpt
    eval_model = SimpleModel()
    eval_model.load(path=ckpt)
    evaluation(eval_model, dev_loader, parameters, counters_dev,
               "improved_model_eval", True)
    print("End part 1")
    print(
        "=============================================================================="
    )
    playing_with_learning_rate(train_loader, parameters)
    # NOTE(review): the adversarial part uses the original (unfixed) dev
    # dataset; confirm that is intended.
    adversarial_example(torch_dev_dataset, parameters, ckpt)
    return ckpt
示例#20
0
class TestBasicModelCreationAndAssignment(unittest.TestCase):
    """Persistence round-trips for ``SimpleModel``, ``FunModel`` and ``PersonModel``.

    Uses the canonical ``assert*`` methods. The ``failUnless*``/``failIf*``
    aliases were deprecated and removed in Python 3.12.
    """

    def setUp(self):
        self.o1 = SimpleModel()
        self.o2 = SimpleModel()
        self.o3 = SimpleModel()

    def test_simple_save(self):
        """Saving an object gives it an ``_id``."""
        self.o1.int1 = 44
        self.o1.save()
        self.assertIsNotNone(self.o1._id)

    def test_update_object(self):
        """Saving a re-fetched object replaces the stored version."""
        #make sure that we replace objects when they are updated
        self.o1._id = "mustafa"
        self.o1.int1 = 1
        self.o1.int2 = 2
        self.o1.save()
        ob = SimpleModel.get_id("mustafa")
        ob.int2 = 3
        ob.save()
        ob = SimpleModel.get_id("mustafa")
        self.assertEqual(3, ob.int2)

    def test_merge(self):
        """``merge()`` adds new fields to an existing document without dropping old ones."""
        f = FunModel(_id='7', enum='red')
        f.save()
        g = FunModel(_id='7', dic={'three': 4})
        g.merge()
        ob = FunModel.get_id('7').to_d()
        # Drop keys whose values this test does not fix: 'ca' (presumably a
        # creation timestamp) and the optional '_rev'.
        del ob['ca']
        ob.pop('_rev', None)
        self.assertEqual(ob, dict(_id='7', e='red', d={'three': 4}))

    def test_missing_fields(self):
        """A field never stored reads back as None."""
        obj1 = SimpleModel({'_id': 'simba', 'i1': 2})
        obj1.save()
        ob = SimpleModel.get_id('simba')
        self.assertIsNone(ob.int2)

    def test_set_missing_field(self):
        """A field absent from the stored document can be set and saved later."""
        SimpleModel({'i1': 2, '_id': 'timon'}).save()
        ob = SimpleModel.get_id('timon')
        ob.int2 = 15
        ob.save()
        ob = SimpleModel.get_id('timon')
        self.assertEqual(ob.int2, 15)

    def test_remove_field(self):
        """Setting a saved field to None and saving again clears it."""
        self.o2._id = "nala"
        self.o2.int1 = 2
        self.o2.int2 = 3
        self.o2.save()
        item = SimpleModel.get_id("nala")
        self.assertEqual(item.int2, 3)
        item.int2 = None
        item.save()
        result = SimpleModel.get_id("nala")
        self.assertIsNone(result.int2)

    def test_get_all(self):
        """``get_all`` returns every saved person and honours ``limit``."""
        # Ages are 10 + len(name): zazu=14, pumba=15, rafiki=16.
        for name in ['pumba', 'zazu', 'rafiki']:
            m = PersonModel(name=name, age=(10 + len(name)))
            m.save()

        people = sorted(PersonModel.get_all(), key=attrgetter('age'))
        self.assertEqual(people[0].name, 'zazu')
        self.assertEqual(people[0].age, 14)
        self.assertEqual(people[1].name, 'pumba')
        self.assertEqual(people[1].age, 15)
        self.assertEqual(people[2].name, 'rafiki')
        self.assertEqual(people[2].age, 16)
        people = list(PersonModel.get_all(limit=2))
        self.assertEqual(len(people), 2)

    def test_fun_model(self):
        """Every ``FunModel`` field type, including a nested ``PersonModel``, survives a save/load."""
        dic = {"one": 2, 'three': "four", 'five': ["six", 7]}
        names = ['Shenzi', 'Banzai', 'ed']
        # Kept naive UTC so the subtraction with ``fun.created`` below is
        # between values of the same kind.
        now = datetime.datetime.utcnow()
        fun = FunModel(
                _id="fun",
                enum="red",
                real=3.14,
                dic=dic,
                names=names,
                )
        fun.part = PersonModel(name="scar", age=32)
        fun.save()
        fun = FunModel.get_id("fun")
        self.assertEqual(fun.enum, 'red')
        self.assertEqual(fun.real, 3.14)
        self.assertEqual(fun.dic, dic)
        # Same condition as "days == 0 and seconds == 0" on a non-negative
        # timedelta: the creation time is within one second of ``now``.
        dt = abs(fun.created - now)
        self.assertLess(dt, datetime.timedelta(seconds=1))
        self.assertEqual(fun.names, names)
        self.assertEqual(fun.part.name, "scar")
        self.assertEqual(fun.part.age, 32)
def start_training():
    """Train a naive policy-gradient agent configured by the request's JSON body.

    The JSON body is turned into a ``BaseAgentConfig``. Its ``"name"`` picks the
    experiment folder ``experiments/naive/<name>``, and its ``"replace"`` flag
    decides whether an existing folder is deleted or the request is refused.
    The folder receives ``experiment.log``, ``configurations.json`` and an
    ``experiment_information.json`` summary.

    :return: ``(experiment_info, status)``. ``experiment_info`` has the keys
        ``mean_test_reward``, ``description``, ``git_hash`` and ``train_time``.
        ``status`` is 200, or 400 when ``"name"`` would place the experiment
        outside ``experiments/naive``. Non-POST requests return ``None``.
    """
    if request.method != "POST":
        # NOTE(review): unchanged behaviour. Non-POST requests return None,
        # which Flask rejects as a response; presumably the route is POST-only.
        return None

    args_dict = request.get_json()
    print(args_dict)

    agent_type = "naive"  # TODO: Make variable
    agent_path = Path("experiments", agent_type, args_dict["name"])
    agent_config = BaseAgentConfig(config_dict=args_dict)

    # Get git version
    repo = git.Repo(search_parent_directories=True)
    sha = repo.head.object.hexsha

    # "name" comes straight from the client. A value such as "../.." or ""
    # resolves outside experiments/naive, and the replace branch below would
    # then rmtree() an arbitrary directory, so reject it up front.
    experiments_root = Path("experiments", agent_type).resolve()
    if experiments_root not in agent_path.resolve().parents:
        description = f"Invalid experiment name {args_dict['name']!r}."
        return _experiment_info(None, description, sha, None), 400

    # Create experiment folder and handle old results
    deleted_old = False
    if agent_path.exists():
        if not args_dict["replace"]:
            description = (f"The experiment {agent_path} already exists. "
                           f"Change experiment name or use the replace "
                           f"option to overwrite.")
            return _experiment_info(None, description, sha, None), 200
        shutil.rmtree(agent_path)
        deleted_old = True

    agent_path.mkdir(parents=True)

    # Detach and close exactly the handler(s) prepare_file_logger adds, even
    # when training raises. The old code removed logger.handlers[1] only on
    # success, leaking the handler (and its open log file) on any failure.
    # Assumes prepare_file_logger attaches its handler(s) to ``logger``.
    handlers_before = list(logger.handlers)
    prepare_file_logger(logger, logging.INFO,
                        Path(agent_path, "experiment.log"))
    added_handlers = [h for h in logger.handlers if h not in handlers_before]
    try:
        # Save experiments configurations and start experiment log
        logger.info(
            f"Running {agent_type} policy gradient on SimpleContinuous")
        if deleted_old:
            logger.info(f"Deleted old experiment in {agent_path}")
        agent_config.log_configurations(logger)
        experiment_config_file = Path(agent_path, "configurations.json")
        logger.info(
            f"Saving experiment configurations to {experiment_config_file}")
        agent_config.to_json_file(experiment_config_file)

        env = BaseSimpleContinuousEnvironment(
            target_action=float(agent_config.true_action), min_reward=-10)
        policy = SimpleModel(model_path=Path(agent_path, "model"),
                             layer_sizes=agent_config.hidden_layer_sizes,
                             learning_rate=agent_config.learning_rate,
                             actions_size=agent_config.actions_size,
                             hidden_activation=agent_config.hidden_activation,
                             mu_activation=agent_config.mu_activation,
                             sigma_activation=agent_config.sigma_activation,
                             start_mu=agent_config.start_mu,
                             start_sigma=agent_config.start_sigma)
        agent = NaivePolicyGradientAgent(env=env,
                                         agent_path=agent_path,
                                         policy=policy,
                                         agent_config=agent_config)

        start_time = time.time()
        test_reward = agent.train_policy(
            train_steps=agent_config.training_steps,
            experience_size=agent_config.experience_size,
            show_every=agent_config.show_every,
            save_policy_every=agent_config.save_policy_every,
            minibatch_size=agent_config.minibatch_size)
        train_time = time.time() - start_time

        experiment_info = _experiment_info(float(test_reward),
                                           agent_config.desc, sha, train_time)

        with open(Path(agent_path, "experiment_information.json"),
                  "w") as outfile:
            json.dump(experiment_info, outfile, indent=4)
    finally:
        for handler in added_handlers:
            logger.removeHandler(handler)
            handler.close()

    return experiment_info, 200


def _experiment_info(mean_test_reward, description, git_hash, train_time):
    """Build the summary dict returned by ``start_training`` (key order preserved)."""
    return {
        "mean_test_reward": mean_test_reward,
        "description": description,
        "git_hash": git_hash,
        "train_time": train_time,
    }
示例#22
0
from models import SimpleModel
from transforms import AddGaussianNoise

datasets = [MyDataset(get_dataset_as_array())]

# Add three more copies of the data, each wrapped with the AddGaussianNoise transform.
for _ in range(3):
    noisey_dataset = MyDataset(get_dataset_as_array(), AddGaussianNoise())
    datasets.append(noisey_dataset)

dataset = torch.utils.data.ConcatDataset(datasets)

datasetloader = torch.utils.data.DataLoader(dataset,
                                            batch_size=50,
                                            shuffle=True)

model = SimpleModel()
# The model stays on the CPU here, so map tensors to the CPU. This also lets a
# checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load('./data/pre_trained.ckpt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
# Switch any dropout / batch-norm layers to inference behaviour. Without this,
# the accuracy below is measured with the model still in training mode.
model.eval()

correct_test = 0
total = 0
with torch.no_grad():
    for data in datasetloader:
        images, labels = data
        outputs = model(images)
        predicted = torch.argmax(outputs, dim=1)
        total += labels.size(0)
        correct_test += (predicted == labels).sum().item()
print(f'total: {total} ')
if total:
    print(f'Model Accuracy is : {100 * correct_test / total:.2f}%')
else:
    # Avoid ZeroDivisionError when the loader yields no samples.
    print('Model Accuracy is : n/a (empty dataset)')
 def setUp(self):
     """Build ``SimpleModel`` and its forward output on a random (5, 256*256*3) input."""
     model = SimpleModel()
     self.model = model
     self.out = model.forward(tf.random.normal(shape=(5, 256 * 256 * 3)))
示例#24
0
 def test_missing_fields(self):
     """A field never stored reads back as None."""
     obj1 = SimpleModel({'_id': 'simba', 'i1': 2})
     obj1.save()
     ob = SimpleModel.get_id('simba')
     # assertIsNone replaces failUnlessEqual(x, None); the alias was removed in 3.12.
     self.assertIsNone(ob.int2)
示例#25
0
 def setUp(self):
     """Give every test three fresh, unsaved ``SimpleModel`` instances."""
     self.o1, self.o2, self.o3 = SimpleModel(), SimpleModel(), SimpleModel()
 def setUp(self):
     """Start each test with a brand-new ``SimpleModel`` stored in ``self.mod``."""
     fresh_model = SimpleModel()
     self.mod = fresh_model
示例#27
0
class TestBasicModelCreationAndAssignment(unittest.TestCase):
    """Field assignment, validation and serialization for ``SimpleModel`` and ``FunModel``.

    Uses the canonical ``assert*`` methods. The ``failUnless*`` aliases were
    deprecated and removed in Python 3.12.
    """

    def setUp(self):
        self.o1 = SimpleModel()
        self.o2 = SimpleModel()
        self.o3 = SimpleModel()

    def test_simple_assign_bogus(self):
        self.assertRaises(ValueError, setattr, self.o1, 'int1', 'bogus')

    def test_simple_assign_obvious(self):
        # test for an obvious integer
        self.o1.int1 = 5
        self.assertEqual(self.o1.int1, 5)

        # now test for changing to an obvious integer
        self.o1.int1 = 9
        self.assertEqual(self.o1.int1, 9)

    def test_simple_arith(self):
        self.o1.int1 = 6
        self.assertEqual(7, 1 + self.o1.int1)
        self.assertGreater(self.o1.int1, 3)

    def test_simple_assign_to_multiple(self):
        '''
        Since we're doing some odd introspection and catching assignments,
        let's ensure that we're actually creating new objects when a value is
        assigned instead of just overwriting the _value of a previous object
        '''
        self.o1.int1 = 8
        self.o2.int1 = 3
        self.assertNotEqual(self.o1.int1, self.o2.int1)

    def test_dict_creation(self):
        # to_d() uses the short key 'i1' for int1.
        self.o1.int1 = 1
        self.assertEqual(self.o1.to_d(), {'i1': 1})

    def test_init_from_dict(self):
        # Both long ('int1') and short ('i1', 'i2') keys are accepted.
        obj1 = SimpleModel({'int1': 2})
        obj2 = SimpleModel(dict(i1=3, i2=7))
        self.assertEqual(2, obj1.int1)
        self.assertEqual(3, obj2.int1)
        self.assertEqual(7, obj2.int2)

    def test_ignored(self):
        o = SimpleModel(int1=17, i2=13, secret=42, keep=100)
        self.assertEqual(17, o.int1)
        self.assertEqual(13, o.int2)
        self.assertRaises(AttributeError, getattr, o, 'i2')
        self.assertEqual(100, o.keep)
        self.assertEqual(42, o.secret)
        self.assertEqual(o.to_d(), {'i1': 17, 'i2': 13, 'keep': 100})

    def test_fun_model(self):
        fun = FunModel()
        fun.part = PersonModel()
        self.assertRaises(TypeError, setattr, fun, 'enum', 'green')
        self.assertRaises(ValueError, setattr, fun, 'real', 'i')
        self.assertRaises(TypeError, setattr, fun, 'dic', [2, 3])
        self.assertRaises(TypeError, setattr, fun, 'created', [])
        self.assertRaises(TypeError, setattr, fun, 'names', [7, 8])
        self.assertRaises(TypeError, setattr, fun, 'names', 13)
        self.assertEqual(fun.enum, None)
        self.assertEqual(fun.part.age, 7)
        fun.part.age = 100
        self.assertEqual(fun.part.age, 100)
        # Assigning a plain dict replaces the nested model; unset fields fall
        # back to their defaults (age 7, as asserted above).
        fun.part = {'n': 'jeff'}
        self.assertEqual(fun.part.age, 7)
        self.assertEqual(fun.part.name, 'jeff')

    def test_date(self):
        fun = FunModel(date=datetime(2005, 1, 2, 13))

        def do_to_d(df):
            return fun.to_d(dateformat=df)['dt']
        self.assertEqual(fun.date, datetime(2005, 1, 2, 13))
        self.assertEqual(do_to_d('datetime'), datetime(2005, 1, 2, 13))
        fun.date = [2005, 1, 2, 16]
        self.assertEqual(fun.date, datetime(2005, 1, 2, 16))
        self.assertEqual(do_to_d('list'), (2005, 1, 2, 16, 0, 0))
        fun.date = 1104660000
        self.assertEqual(fun.date, datetime(2005, 1, 2, 10))
        self.assertEqual(do_to_d('epoch'), 1104660000)

    def test_long_names(self):
        self.o1.int1 = 3
        self.assertEqual(self.o1.to_d(long_names=True), {'int1': 3})
        self.assertEqual(SimpleModel.long_names,
                {'i1': 'int1', 'i2': 'int2', '_rev': '_rev', '_id': '_id'})
示例#28
0
文件: train.py 项目: Xelanos/APML
import torch
from torch.utils.data import RandomSampler

EPOCHS = 120

labels = label_names()
dataset = get_dataset_as_torch_dataset()

# Randomly hold out 500 samples for testing; the rest are for training.
trainset, testset = torch.utils.data.random_split(dataset, [len(dataset) - 500, 500])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=20, shuffle=True)
# Evaluate on the held-out split. This loader used to wrap trainset, so
# "test" results were measured on training data.
testloader = torch.utils.data.DataLoader(testset, batch_size=500, shuffle=True)

classes = tuple(labels.values())

model = SimpleModel()
loss_fn = torch.nn.CrossEntropyLoss()
# NOTE(review): learning_rate is not used by the optimizer below (it uses lr=1e-2).
learning_rate = 1e-4
optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2, lr_decay=1e-5)


# NOTE(review): only the start of the training loop is visible here; the
# forward/backward/optimizer steps are not part of this excerpt.
for epoch in range(EPOCHS):  # loop over the dataset multiple times
    # Progress line without a trailing newline (end='').
    print(f"EPOCH {epoch + 1}/{EPOCHS}", end='')

    # Per-epoch accumulators.
    running_loss = 0.0
    correct_train = 0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        # NOTE(review): this rebinds the module-level ``labels`` (the
        # label_names() mapping) to the batch targets. ``classes`` was built
        # from it earlier, but any later use of ``labels`` as the mapping breaks.
        inputs, labels = data
示例#29
0
 def test_init_from_dict(self):
     """The constructor accepts long ('int1') and short ('i1', 'i2') field keys."""
     obj1 = SimpleModel({'int1': 2})
     obj2 = SimpleModel(dict(i1=3, i2=7))
     # assertEqual replaces failUnlessEqual (alias removed in Python 3.12).
     self.assertEqual(2, obj1.int1)
     self.assertEqual(3, obj2.int1)
     self.assertEqual(7, obj2.int2)
示例#30
0
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# Show one batch of training images and their labels.
dataiter = iter(train_loader)
# DataLoader iterators no longer have a ``.next()`` method in recent PyTorch
# releases; the built-in next() works on every version.
images, labels = next(dataiter)
# The extra arguments are presumably the normalisation mean/std that imshow undoes.
imshow(torchvision.utils.make_grid(images), 0.1307, 0.3081)
print(labels)

# Get Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
from models import SimpleModel
model = SimpleModel().to(device)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Number of batches per epoch.
num_steps = len(train_loader)

# NOTE(review): only the start of the epoch loop is visible in this excerpt.
for epoch in range(num_epochs):

    # ---------- TRAINING ----------
    # set model to training
    model.train()
示例#31
0
import tensorflow as tf
import numpy as np

from data.tf_datasets import OmniglotDataset
from models import SimpleModel

# Build a TF1 graph that runs SimpleModel on one Omniglot test task and
# records image/histogram summaries for TensorBoard.
# NOTE(review): model_address is not used in the visible code; presumably a
# checkpoint is restored from it later in the file.
model_address = './saved_models/simple_model-1000'
model = SimpleModel()

omniglot_dataset = OmniglotDataset()
test_dataset = omniglot_dataset.get_test_dataset()
# presumably n-way (n=6) k-shot (k=2) tasks; confirm against
# get_supervised_meta_learning_tasks.
train_task, val_task, train_labels, val_labels = test_dataset.get_supervised_meta_learning_tasks(
    meta_batch_size=1, n=6, k=2)

# Log the task inputs reshaped to 28x28 single-channel images (up to 12 of them).
tf.summary.image('task',
                 tf.reshape(train_task, (-1, 28, 28, 1)),
                 max_outputs=12)

model.forward(train_task)
model.define_update_op(train_labels, with_batch_norm_dependency=True)

# One histogram summary per global variable.
for item in tf.global_variables():
    tf.summary.histogram(item.name, item)

merged_summary = tf.summary.merge_all()
# NOTE(review): "adaptaion" in the summary paths looks like a typo of
# "adaptation"; left as-is because changing it would move the log directory.
train_writer = tf.summary.FileWriter('./adaptaion_summary/train',
                                     tf.get_default_graph())
test_writer = tf.summary.FileWriter('./adaptaion_summary/test')

# NOTE(review): only variable initialisation is visible in this excerpt.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
from dataset import get_train_dataloader, get_test_dataloader
from utils import parse_args

# Script entry point: build data loaders, the model, optimizer and LR
# scheduler from CLI args, then log baseline accuracies to Weights & Biases.
if __name__ == '__main__':
    args = parse_args()
    # Use CUDA only when available and not explicitly disabled with use_cpu.
    use_cuda = not args.use_cpu and torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    bs = args.train_batch_size

    train_dataloader = get_train_dataloader(
        os.path.join(args.data_dir, 'train/'), args.train_batch_size,
        args.augmentation)
    test_dataloader = get_test_dataloader(os.path.join(args.data_dir, 'test/'),
                                          args.test_batch_size)

    model = SimpleModel(use_bn=args.use_bn).to(device)

    wandb.init(project="classifying-celebrities", config=args)
    wandb.watch(model, log='all')
    # NOTE(review): config is not used in the visible code.
    config = wandb.config

    loss_function = CrossEntropyLoss(reduction='mean')
    optimizer = dispatch_optimizer(model, args)
    lr_scheduler = dispatch_lr_scheduler(optimizer, args)

    # Baseline accuracies before any training step. The wandb step is
    # iteration * batch size, presumably counting samples seen.
    iteration = 0
    training_accuracy = compute_accuracy(model, train_dataloader, device)
    test_accuracy = compute_accuracy(model, test_dataloader, device)
    # NOTE(review): the metric keys use different conventions
    # ('training accuracy' vs 'test_accuracy'); kept because dashboards may
    # depend on them.
    wandb.log({'training accuracy': training_accuracy}, step=iteration * bs)
    wandb.log({'test_accuracy': test_accuracy}, step=iteration * bs)