Example #1
def train_post_prototype(assignment):
    import torch
    from time import time
    from config import get_prototype_config
    from model_torch import Post_Prototype_RCNN_L2
    from stream import Post_Data
    config = get_prototype_config()
    ## 1. prepare data
    trainData = Post_Data(is_train=True, **config)
    testData = Post_Data(is_train=False, **config)
    ### 2. model & train
    nn = Post_Prototype_RCNN_L2(assignment, **config)
    LR = config['LR']
    opt_ = torch.optim.SGD(nn.parameters(), lr=LR)
    best_auc = 0
    trainFeature_all, _ = trainData.get_all()
    testFeature, testLabel = testData.get_all()
    for i in range(config['train_iter']):
        nn.generate_prototype(trainFeature_all)
        feat, label = trainData.next()
        loss, _ = nn(feat, label)
        opt_.zero_grad()
        loss.backward()
        opt_.step()
        loss_value = loss.item()  # .data[0] indexing was removed in PyTorch 0.4
        if i % 1 == 0:  # evaluate every iteration
            score = post_evaluate(testFeature, testLabel, nn, **config)
            best_auc = score if score > best_auc else best_auc
            print('iter {}, auc:{} (best:{})'.format(i,
                                                     str(score)[:5],
                                                     str(best_auc)[:5]),
                  end=' ')
            if i > 0:
                print('{} sec'.format(str(time() - t1)[:4]))
            t1 = time()
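
Note: post_evaluate is not defined in this snippet. A minimal sketch of what it plausibly does (ROC AUC on the test set), assuming the model exposes a predict-style call returning per-example scores; that interface is an assumption, not the original API:

import torch
from sklearn.metrics import roc_auc_score

def post_evaluate(feature, label, model, **config):
    # Hypothetical stand-in for the undefined helper.
    with torch.no_grad():
        scores = model.predict(feature)  # assumed: 1-D tensor of scores
    return roc_auc_score(label, scores.cpu().numpy())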
Example #2
    def test_conv_in_computational_graph(self):
        nn = NN(3)
        for param in nn.parameters():
            assert param.requires_grad

        afl = AffineCouplingLayer(3)
        for param in afl.parameters():
            assert param.requires_grad
Example #3
import numpy as np
import torch


def parameter_stats(nn: torch.nn.Module, print_individual_params=True):
    """Prints various statistics about the parameters of a neural network."""
    p = torch.nn.utils.parameters_to_vector(nn.parameters()).detach().numpy()
    print(
        f"max/min = {max(p):.1f}/{min(p):.1f}, mean={np.mean(p):.1f}+/-{np.std(p):.2f}"
    )
    if print_individual_params:
        for name, p in nn.named_parameters():
            p = p.detach().numpy()
            print(
                f"{name:40s}max/min = {np.max(p):.1f}/{np.min(p):.1f}, mean={np.mean(p):.1f}+/-{np.std(p):.2f}"
            )
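
A quick usage sketch (the tiny two-layer model is illustrative only):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
parameter_stats(model, print_individual_params=False)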
Example #4
def train_weighted_post_prototype(assignment):
    import torch
    from config import get_prototype_config
    from model_torch import Weighted_Post_Prototype_RCNN_L2
    from stream import Weighted_Post_Data
    config = get_prototype_config()
    every_iter = config['every_iter']
    ## 1. prepare data
    trainData = Weighted_Post_Data(is_train=True, **config)
    testData = Weighted_Post_Data(is_train=False, **config)
    ### 2. model & train
    nn = Weighted_Post_Prototype_RCNN_L2(assignment, **config)
    LR = config['LR']
    opt_ = torch.optim.SGD(nn.parameters(), lr=LR)
    best_auc = 0
    trainFeature_all, _, weight_all = trainData.get_all()
    testFeature, testLabel, _ = testData.get_all()
    for i in range(config['train_iter']):
        nn.generate_prototype(trainFeature_all, weight_all)
        feat, label, weight = trainData.next()
        loss, _ = nn(feat, label, weight)
        opt_.zero_grad()
        loss.backward()
        opt_.step()
        # loss_value = loss.item()
        if i % 1 == 0:  # evaluate every iteration
            score = post_evaluate(testFeature, testLabel, nn, **config)
            best_auc = score if score > best_auc else best_auc
            if i % every_iter == every_iter - 1 and i > 0:
                print('iter {}, auc:{} (best:{})'.format(
                    i,
                    str(score)[:5],
                    str(best_auc)[:5]),
                      end=' \n')
            #	print('{} sec'.format(str(time() - t1)[:4]))
            #t1 = time()

    trainData.restart()
    reweight = []
    for i in range(trainData.batch_number):
        feat, label, weight = trainData.next()
        reweight.extend(nn.measure_similarity(feat))
    #print(reweight)
    reweight = normalize_weight(reweight, config['upper'], config['lower'])
    #print(reweight)
    return reweight
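
normalize_weight is not shown in this snippet. A minimal sketch, assuming it min-max rescales the similarity scores into [lower, upper]:

import numpy as np

def normalize_weight(weight, upper, lower):
    # Assumed behaviour: linear rescale into [lower, upper].
    w = np.asarray([float(x) for x in weight])
    span = w.max() - w.min()
    if span == 0:
        return np.full_like(w, (upper + lower) / 2.0)
    return lower + (w - w.min()) * (upper - lower) / span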
Example #5
    def train(self, grad_mod):
        '''
        Runs one epoch of the full training step.
        '''
        for model in self.mrf.nnformulas:
            model.train()

        self.optimizer.zero_grad()
        preds = self.forward()
        # print(preds[1])

        loss = criterion[0](preds[0], y_true)  # criterion and y_true come from module scope (sketch below)

        loss.backward()
        if grad_mod:
            gradient = self.grad(preds)
            for fidx, nn in enumerate(self.mrf.nnformulas):
                for par in nn.parameters():
                    par.grad *= gradient[fidx]  # scale each formula's gradients by its factor

        self.optimizer.step()
        return loss
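
criterion and y_true are read from module scope in this method. A hedged sketch of the setup it appears to assume (names and shapes are illustrative only):

import torch

criterion = [torch.nn.CrossEntropyLoss()]  # one loss per output head, indexed as criterion[0]
y_true = torch.tensor([0, 1, 1, 0])        # illustrative ground-truth labels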
Example #6
        x = self.pool(F.relu(self.conv2(x)))
        #x = F.relu(self.conv2(x))

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(self.do(x)))
        #x = F.relu(self.fc(x))

        x = self.out(x)
        return x


if __name__ == '__main__':
    #Test code
    nn = net(2, manifold=geoopt.manifolds.PoincareBall())
    #nn = net(2,manifold=geoopt.manifolds.Euclidean())
    #nn = net(2,manifold=geoopt.manifolds.Stiefel())

    x = torch.randn(128, 1024, 2)

    y = nn(x)

    opt = geoopt.optim.RiemannianAdam(nn.parameters(), lr=1, stabilize=1)
    opt.zero_grad()
    opt.step()  # smoke test: no backward pass has run, so there are no gradients to apply
    print(nn)

    #print(y)
Example #7
def test(models, epoch, f=None):
    global num_tests
    num_tests += 1

    class MStat:
        def __init__(self, model):
            model.eval()
            self.model = model
            self.correct = 0

            class Stat:
                def __init__(self, d, dnm):
                    self.domain = d
                    self.name = dnm
                    self.width = 0
                    self.max_eps = None
                    self.safe = 0
                    self.proved = 0
                    self.time = 0

            self.domains = [
                Stat(h.parseValues(d, goals), h.catStrs(d))
                for d in args.test_domain
            ]

    model_stats = [MStat(m) for m in models]
    dict_map = dict(np.load("./dataset/AG/dict_map.npy", allow_pickle=True).item())
    lines = open("./dataset/en.key1").readlines()
    adjacent_keys = [[] for i in range(len(dict_map))]
    for line in lines:
        tmp = line.strip().split()
        ret = set(tmp[1:]).intersection(dict_map.keys())
        ids = []
        for x in ret:
            ids.append(dict_map[x])
        adjacent_keys[dict_map[tmp[0]]].extend(ids)

    num_its = 0
    saved_data_target = []
    for data, target in test_loader:
        if num_its >= args.test_size:
            break

        if num_tests == 1:
            saved_data_target += list(zip(list(data), list(target)))

        num_its += data.size()[0]
        if num_its % 100 == 0:
            print(num_its, model_stats[0].domains[0].safe * 100.0 / num_its)
        if args.test_swap_delta > 0:
            length = data.size()[1]
            data = data.repeat(1, length)
            for i in data:
                for j in range(length - 1):
                    for _ in range(args.test_swap_delta):
                        t = np.random.randint(0, length)
                        while len(adjacent_keys[int(i[t])]) == 0:
                            t = np.random.randint(0, length)
                        cid = int(i[t])
                        i[j * length + t] = adjacent_keys[cid][0]
            target = (target.view(-1, 1).repeat(1, length)).view(-1)
            data = data.view(-1, length)

        if h.use_cuda:
            data, target = data.cuda().to_dtype(), target.cuda()

        for m in model_stats:

            with torch.no_grad():
                pred = m.model(data).vanillaTensorPart().max(1, keepdim=True)[1]  # index of the max log-probability
                m.correct += pred.eq(target.data.view_as(pred)).sum()

            for stat in m.domains:
                timer = Timer(shouldPrint=False)
                with timer:

                    def calcData(data, target):
                        box = stat.domain.box(data,
                                              w=m.model.w,
                                              model=m.model,
                                              untargeted=True,
                                              target=target).to_dtype()
                        with torch.no_grad():
                            bs = m.model(box)
                            org = m.model(data).vanillaTensorPart().max(1, keepdim=True)[1]
                            stat.width += bs.diameter().sum().item()  # accumulate batch width
                            stat.proved += bs.isSafe(org).sum().item()
                            stat.safe += bs.isSafe(target).sum().item()
                            # stat.max_eps += 0  # TODO: calculate max_eps

                    if m.model.net.neuronCount() < 5000 or stat.domain in SYMETRIC_DOMAINS:
                        calcData(data, target)
                    else:
                        if args.test_swap_delta > 0:
                            length = data.size()[1]
                            pre_stat = copy.deepcopy(stat)
                            for i, (d, t) in enumerate(zip(data, target)):
                                calcData(d.unsqueeze(0), t.unsqueeze(0))
                                if (i + 1) % length == 0:
                                    d_proved = (stat.proved -
                                                pre_stat.proved) // length
                                    d_safe = (stat.safe -
                                              pre_stat.safe) // length
                                    d_width = (stat.width -
                                               pre_stat.width) / length
                                    stat.proved = pre_stat.proved + d_proved
                                    stat.safe = pre_stat.safe + d_safe
                                    stat.width = pre_stat.width + d_width
                                    pre_stat = copy.deepcopy(stat)
                        else:
                            for d, t in zip(data, target):
                                calcData(d.unsqueeze(0), t.unsqueeze(0))
                stat.time += timer.getUnitTime()

    l = num_its  # len(test_loader.dataset)
    for m in model_stats:
        if args.lr_multistep:
            m.model.lrschedule.step()

        pr_corr = float(m.correct) / float(l)
        if args.use_schedule:
            m.model.lrschedule.step(1 - pr_corr)

        h.printBoth(
            ('Test: {:12} trained with {:' + str(largest_domain) +
             '} - Avg sec/ex {:1.12f}, Accuracy: {}/{} ({:3.1f}%)').format(
                 m.model.name, m.model.ty.name, m.model.speed, m.correct, l,
                 100. * pr_corr),
            f=f)

        model_stat_rec = ""
        for stat in m.domains:
            pr_safe = stat.safe / l
            pr_proved = stat.proved / l
            pr_corr_given_proved = pr_safe / pr_proved if pr_proved > 0 else 0.0
            h.printBoth((
                "\t{:" + str(largest_test_domain) +
                "} - Width: {:<36.16f} Pr[Proved]={:<1.3f}  Pr[Corr and Proved]={:<1.3f}  Pr[Corr|Proved]={:<1.3f} {}Time = {:<7.5f}"
            ).format(
                stat.name, stat.width / l, pr_proved, pr_safe,
                pr_corr_given_proved,
                "AvgMaxEps: {:1.10f} ".format(stat.max_eps / l)
                if stat.max_eps is not None else "", stat.time),
                        f=f)
            model_stat_rec += "{}_{:1.3f}_{:1.3f}_{:1.3f}__".format(
                stat.name, pr_proved, pr_safe, pr_corr_given_proved)
        prepedname = m.model.ty.name.replace(" ", "_").replace(
            ",", "").replace("(", "_").replace(")", "_").replace("=", "_")
        net_file = os.path.join(
            out_dir, m.model.name + "__" + prepedname + "_checkpoint_" +
            str(epoch) + "_with_{:1.3f}".format(pr_corr))

        h.printBoth("\tSaving netfile: {}\n".format(net_file + ".pynet"), f=f)

        if (num_tests % args.save_freq == 1 or args.save_freq
                == 1) and not args.dont_write and (num_tests > 1
                                                   or args.write_first):
            print("Actually Saving")
            torch.save(m.model.net, net_file + ".pynet")
            if args.save_dot_net:
                with h.mopen(args.dont_write, net_file + ".net", "w") as f2:
                    m.model.net.printNet(f2)
                    f2.close()
            if args.onyx:
                nn = copy.deepcopy(m.model.net)
                nn.remove_norm()
                torch.onnx.export(
                    nn,
                    h.zeros([1] + list(input_dims)),
                    net_file + ".onyx",
                    verbose=False,
                    input_names=["actual_input"] + [
                        "param" + str(i)
                        for i in range(len(list(nn.parameters())))
                    ],
                    output_names=["output"])

    if num_tests == 1 and not args.dont_write:
        img_dir = os.path.join(out_dir, "images")
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        for img_num, (img, target) in zip(
                range(args.number_save_images),
                saved_data_target[:args.number_save_images]):
            sz = ""
            for s in img.size():
                sz += str(s) + "x"
            sz = sz[:-1]

            img_file = os.path.join(
                img_dir, args.dataset + "_" + sz + "_" + str(img_num))
            if img_num == 0:
                print("Saving image to: ", img_file + ".img")
            with open(img_file + ".img", "w") as imgfile:
                flatimg = img.view(h.product(img.size()))
                for t in flatimg.cpu():
                    print(decimal.Decimal(float(t)).__format__("f"),
                          file=imgfile)
            with open(img_file + ".class", "w") as imgfile:
                print(int(target.item()), file=imgfile)
Example #8
    def train(self):
        self.bnn.train(self.X, self.y)
        for nn in self.bnn.nns:
            for p in nn.parameters():
                p.requires_grad = False  # freeze the trained networks
Example #9
# In[9]:


n_actions = env.action_space.n
nn = reinforce(n_actions)


# In[12]:


num_episodes = 500
episode_durations = []
episode_mean = []
memory = ReplayMemory()
optimizer = optim.RMSprop(nn.parameters(), lr=8e-3)  # note: the original 8 * math.exp(-3) evaluates to ~0.398; 8e-3 is presumably the intended rate

# print(num_episodes)
for i in range(num_episodes):
    # print(i)
    state = torch.FloatTensor([env.reset()])
    for t in count():
        # print(state)
        action = select_action(state)
        next_state, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        # Observe new state
        if not done:
            next_state = torch.FloatTensor([next_state])
        else:
            next_state = None
Example #10
        # self.bn = torch.nn.BatchNorm1d(nn_config[0])

        self.fc = nn.ModuleList()
        for x in range(0, len(af_config) - 1):
            if af_config[x] is not None:
                self.fc.append(af_config[x])
            self.fc.append(nn.Linear(nn_arch[x], nn_arch[x + 1]))
            if init_w[x] is not None:
                self.fc[-1].weight.data.uniform_(-init_w[x], init_w[x])
            else:
                s = self.fc[-1].weight.data.size()[0]
                s = 1. / np.sqrt(s)
                self.fc[-1].weight.data.uniform_(-s, s)

        if af_config[-1] is not None:
            self.fc.append(af_config[-1])

    def forward(self, x):
        for layer in self.fc:
            x = layer(x)
        return x


if __name__ == '__main__':
    nc = NN_CONFIG(nn_arch=[10, 100, 10], af_config=[None, nn.modules.ReLU(), nn.modules.ReLU()], init_w=[None, None],
                   lr=0.1, optim=None, load_path=None)

    nn = Net(nc.nn_arch, nc.af_config, nc.init_w)
    nn.to(torch.cuda.current_device())
    print(next(nn.parameters()).device)
Example #11
File: BO.py Project: xlchan/BNN
    def train(self):
        # train on standardized targets (zero mean, unit variance per dimension)
        self.bnn.train(self.X,
                       (self.y - self.y.mean(dim=0)) / self.y.std(dim=0))
        for nn in self.bnn.nns:
            for p in nn.parameters():
                p.requires_grad = False  # freeze the trained networks
Example #12
# setup training set
trainset = HLWDataset('hlw/split/train.txt', opt.imagesize, training=True)
trainset_loader = torch.utils.data.DataLoader(trainset, shuffle=True, num_workers=6, batch_size=opt.batchsize)

# setup ng dsac estimator
loss = Loss(opt.imagesize) 
ngdsac = NGDSAC(opt.hypotheses, opt.inlierthreshold, opt.inlierbeta, opt.inlieralpha, loss, opt.invalidloss)

# setup network
nn = Model(opt.capacity)
nn.train()
nn = nn.cuda()

# optimizer and lr schedule (schedule offset handled further below)
optimizer = optim.Adam(nn.parameters(), lr=opt.learningrate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=opt.schedulestep, gamma=0.5)

# keep track of training progress
train_log = open('log_'+opt.session+'.txt', 'w', 1)

iteration = 0
epochs = int(opt.iterations / len(trainset)) # number of epochs from number of target iterations 

for epoch in range(epochs):

	print('=== Epoch: ', epoch, '========================================')

	for inputs, labels, xStart, xEnd, imh, idx in trainset_loader:

		start_time = time.time()
Example #13
def test(models, epoch, f = None):
    global num_tests
    num_tests += 1
    class MStat:
        def __init__(self, model):
            model.eval()
            self.model = model
            self.correct = 0
            class Stat:
                def __init__(self, d, dnm):
                    self.domain = d
                    self.name = dnm
                    self.width = 0
                    self.max_eps = 0
                    self.safe = 0
                    self.proved = 0
                    self.time = 0
            self.domains = [ Stat(h.parseValues(domains,d), h.catStrs(d)) for d in args.test_domain ]
    model_stats = [ MStat(m) for m in models ]
        
    num_its = 0
    saved_data_target = []
    for data, target in test_loader:
        if num_its >= args.test_size:
            break

        if num_tests == 1:
            saved_data_target += list(zip(list(data), list(target)))
        
        num_its += data.size()[0]
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for m in model_stats:

            with torch.no_grad():
                pred = m.model(data).data.max(1, keepdim=True)[1] # get the index of the max log-probability
                m.correct += pred.eq(target.data.view_as(pred)).sum()

            for stat in m.domains:
                timer = Timer(shouldPrint = False)
                with timer:
                    def calcData(data, target):
                        box = stat.domain.box(data, m.model.w, model=m.model, untargeted = True, target=target)
                        with torch.no_grad():
                            bs = m.model(box)
                            org = m.model(data).max(1,keepdim=True)[1]
                            stat.width += bs.diameter().sum().item()  # accumulate batch width
                            stat.proved += bs.isSafe(org).sum().item()
                            stat.safe += bs.isSafe(target).sum().item()
                            stat.max_eps += 0 # TODO: calculate max_eps

                    if m.model.net.neuronCount() < 5000 or stat.domain in SYMETRIC_DOMAINS:
                        calcData(data, target)
                    else:
                        for d,t in zip(data, target):
                            calcData(d.unsqueeze(0),t.unsqueeze(0))
                stat.time += timer.getUnitTime()
                
    l = num_its # len(test_loader.dataset)
    for m in model_stats:

        pr_corr = float(m.correct) / float(l)
        if args.use_schedule:
            m.model.lrschedule.step(1 - pr_corr)
        
        h.printBoth(('Test: {:12} trained with {:'+ str(largest_domain) +'} - Avg sec/ex {:1.12f}, Accuracy: {}/{} ({:3.1f}%)').format(
            m.model.name, m.model.ty.name,
            m.model.speed,
            m.correct, l, 100. * pr_corr), f = f)
        
        model_stat_rec = ""
        for stat in m.domains:
            pr_safe = stat.safe / l
            pr_proved = stat.proved / l
            pr_corr_given_proved = pr_safe / pr_proved if pr_proved > 0 else 0.0
            h.printBoth(("\t{:" + str(largest_test_domain)+"} - Width: {:<36.16f} Pr[Proved]={:<1.3f}  Pr[Corr and Proved]={:<1.3f}  Pr[Corr|Proved]={:<1.3f} AvgMaxEps: {:1.10f} Time = {:<7.5f}").format(
                stat.name, 
                stat.width / l, 
                pr_proved, 
                pr_safe, pr_corr_given_proved, 
                stat.max_eps / l,
                stat.time), f = f)
            model_stat_rec += "{}_{:1.3f}_{:1.3f}_{:1.3f}__".format(stat.name, pr_proved, pr_safe, pr_corr_given_proved)
        prepedname = m.model.ty.name.replace(" ", "_").replace(",", "").replace("(", "_").replace(")", "_").replace("=", "_")
        net_file = os.path.join(out_dir, m.model.name +"__" +prepedname + "_checkpoint_"+str(epoch)+"_with_{:1.3f}".format(pr_corr))

        h.printBoth("\tSaving netfile: {}\n".format(net_file + ".net"), f = f)

        if (num_tests % args.save_freq == 1 or args.save_freq == 1) and not args.dont_write:
            torch.save(m.model.net, net_file + ".pynet")
            
            with h.mopen(args.dont_write, net_file + ".net", "w") as f2:
                m.model.net.printNet(f2)
                f2.close()
            if args.onyx:
                nn = copy.deepcopy(m.model.net)
                nn.remove_norm()
                torch.onnx.export(nn, h.zeros([1] + list(input_dims)), net_file + ".onyx", 
                                  verbose=False, input_names=["actual_input"] + ["param"+str(i) for i in range(len(list(nn.parameters())))], output_names=["output"])


    if num_tests == 1 and not args.dont_write:
        img_dir = os.path.join(out_dir, "images")
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        for img_num,(img,target) in zip(range(args.number_save_images), saved_data_target[:args.number_save_images]):
            sz = ""
            for s in img.size():
                sz += str(s) + "x"
            sz = sz[:-1]

            img_file = os.path.join(img_dir, args.dataset + "_" + sz + "_"+ str(img_num))
            if img_num == 0:
                print("Saving image to: ", img_file + ".img")
            with open(img_file + ".img", "w") as imgfile:
                flatimg = img.view(h.product(img.size()))
                for t in flatimg.cpu():
                    print(decimal.Decimal(float(t)).__format__("f"), file=imgfile)
            with open(img_file + ".class" , "w") as imgfile:
                print(int(target.item()), file=imgfile)
Example #14
import time

import torch
import torch.nn as nn

x = torch.randn(1000, 30)
y = torch.randn(1000, 10)


class Network(nn.Module):
	def __init__(self):
		super().__init__()
		self.w1 = nn.Linear(30, 50)
		self.w2 = nn.Linear(50, 10)
	def forward(self, x):
		h = torch.relu(self.w1(x))
		o = self.w2(h)
		return o

nn = Network()

op = torch.optim.Adam(nn.parameters(), lr=0.05)
t1 = time.time()
for i in range(2000):
	op.zero_grad()  # clear gradients from the previous step; they accumulate otherwise
	o = nn(x)
	l = torch.mean((o - y) ** 2)
	print(i, float(l))
	l.backward()
	op.step()

t2 = time.time()

print(t2 - t1)


Example #15
def main(tr_conf):
    import torch.nn as nn

    seed = 12345
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    tr = transforms.Compose([transforms.ToTensor()])
    train_data = MovingObjects("train", tr, seed)
    train_loader = DataLoader(train_data,
                              num_workers=tr_conf['b_s'],
                              batch_size=tr_conf['b_s'],
                              shuffle=True,
                              pin_memory=True)

    param_list = []
    in_chs = tr_conf['input_channels']
    flow_l3 = nn.DataParallel(
        _Glow(in_channels=4 * in_chs, mid_channels=512,
              num_steps=24)).to(device)
    flow_l2 = nn.DataParallel(
        _Glow(in_channels=8 * in_chs, mid_channels=512,
              num_steps=24)).to(device)
    flow_l1 = nn.DataParallel(
        _Glow(in_channels=16 * in_chs, mid_channels=512,
              num_steps=24)).to(device)

    nntheta3 = nn.DataParallel(
        NNTheta(encoder_ch_in=4 * in_chs,
                encoder_mode=tr_conf['encoder_mode'],
                h_ch_in=2 * in_chs,
                num_blocks=tr_conf['enc_depth'])).to(device)  # z1:2x32x32
    nntheta2 = nn.DataParallel(
        NNTheta(encoder_ch_in=8 * in_chs,
                encoder_mode=tr_conf['encoder_mode'],
                h_ch_in=4 * in_chs,
                num_blocks=tr_conf['enc_depth'])).to(device)  # z2:4x16x16
    nntheta1 = nn.DataParallel(
        NNTheta(encoder_ch_in=16 * in_chs,
                encoder_mode=tr_conf['encoder_mode'],
                num_blocks=tr_conf['enc_depth'])).to(device)

    model_path = '/b_test/azimi/results/VideoFlow/SMovement/exp12_2/sacred/snapshots/55.pth'
    if tr_conf['resume']:
        print('model loading ...')
        flow_l3.load_state_dict(torch.load(model_path)['glow_l3'])
        flow_l2.load_state_dict(torch.load(model_path)['glow_l2'])
        flow_l1.load_state_dict(torch.load(model_path)['glow_l1'])
        nntheta3.load_state_dict(torch.load(model_path)['nn_theta_l3'])
        nntheta2.load_state_dict(torch.load(model_path)['nn_theta_l2'])
        nntheta1.load_state_dict(torch.load(model_path)['nn_theta_l1'])
        print("****LOAD THE OPTIMIZER")

    glow = Glow(l3=flow_l3, l2=flow_l2, l1=flow_l1)
    nn_theta = NN_Theta(l3=nntheta3, l2=nntheta2, l1=nntheta1)

    for f_level in glow:
        param_list += list(f_level.parameters())

    for nn in nn_theta:
        param_list += list(nn.parameters())

    loss_fn = NLLLossVF()

    optimizer = torch.optim.Adam(param_list, lr=tr_conf['lr'])
    if tr_conf['resume']:
        optimizer.load_state_dict(torch.load(model_path)['optimizer'])
    optimizer.zero_grad()

    # scheduler_step = sched.StepLR(optimizer, step_size=1, gamma=0.99)
    # linear_decay = sched.LambdaLR(optimizer, lambda s: 1. - s / 150000. )
    # linear_decay.step(global_step)

    # scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / 10000))
    # optimizer.load_state_dict(torch.load(model_path)['optimizer'])

    for epoch in range(tr_conf['starting_epoch'], tr_conf['n_epoch']):
        print("the learning rate for epoch {} is {}".format(
            epoch, get_lr(optimizer)))
        train_smovement(train_loader, glow, nn_theta, loss_fn, optimizer, None,
                        epoch)
Example #16
def train():

    if FLAGS.dnn_hidden_units:
        dnn_hidden_units = FLAGS.dnn_hidden_units.split(",")
        dnn_hidden_units = [
            int(dnn_hidden_unit_) for dnn_hidden_unit_ in dnn_hidden_units
        ]
    else:
        dnn_hidden_units = []

    # Get negative slope parameter for LeakyReLU
    neg_slope = FLAGS.neg_slope

    # use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device :", device)

    # load data and sample the first batch
    cifar10 = cifar10_utils.get_cifar10(FLAGS.data_dir)
    x_np, y_np = cifar10['train'].next_batch(FLAGS.batch_size)
    x_np = x_np.reshape(FLAGS.batch_size, -1)  # batchsize * pixels per image

    # initialize MLP
    nn = MLP(x_np.shape[1], dnn_hidden_units, 10, neg_slope).to(device)
    crossEntropy = torch.nn.CrossEntropyLoss()

    # initialize optimizer
    optimizer = torch.optim.SGD(nn.parameters(), lr=FLAGS.learning_rate)

    # initialization for plotting and metrics
    test_accuracies = []
    training_losses = []
    training_accuracies = []
    test_losses = []

    # extract test data
    x_test, y_test_np = cifar10['test'].images, cifar10['test'].labels
    x_test = x_test.reshape(x_test.shape[0], -1)
    x_test = torch.from_numpy(x_test).to(device)
    y_test = torch.from_numpy(y_test_np).to(device)

    # perform the forward step, backward step and updating of weights max_steps number of times,
    for step in range(FLAGS.max_steps):
        if (step + 1) % 100 == 0:
            print(step + 1, "/", FLAGS.max_steps, "\n")

        optimizer.zero_grad()

        # torch.autograd.Variable is deprecated; plain tensors suffice, and the
        # inputs/targets themselves do not need gradients
        x = torch.from_numpy(x_np).to(device)
        y = torch.from_numpy(y_np).to(device)

        pred = nn(x).to(device)

        train_acc = accuracy(pred, y)

        # compute cross entropy loss
        labels = torch.max(y, 1)[1]
        loss = crossEntropy(pred, labels)

        training_accuracies.append(train_acc)
        training_losses.append(loss.item())  # store a float, not a graph-holding tensor

        # evaluation on test set
        if step % FLAGS.eval_freq == 0:
            test_accuracies, test_losses = eval_on_test(
                nn, crossEntropy, x_test, y_test, test_accuracies, test_losses)

        # get a next batch
        x_np, y_np = cifar10['train'].next_batch(FLAGS.batch_size)
        x_np = x_np.reshape(FLAGS.batch_size,
                            -1)  # batchsize * pixels per image

        loss.backward()
        optimizer.step()

    # compute loss and accuracy on the test set a final time
    test_accuracies, test_losses = eval_on_test(nn, crossEntropy, x_test,
                                                y_test, test_accuracies,
                                                test_losses)
    print("Maximum accuracy :", max(test_accuracies),
          "after %d steps\n" % (np.argmax(test_accuracies) * FLAGS.eval_freq))