def main():
    """Train DIRNet on "Data_1" while logging the loss on "Data_2".

    Every iteration fits one batch from each data handler, prints both
    losses and appends ``"<i> <train_loss> <eval_loss>"`` to
    ``<config.result_dir>/data.txt``. Every 1000 iterations the current
    training batch is deployed into ``config.tmp_dir`` and the model is
    saved into ``config.ckpt_dir``. After the loop the last evaluation
    batch is deployed into ``<config.result_dir>/0`` and ``.../1``.

    NOTE(review): the original formatting of this function was lost; the
    final deploy loop is assumed to run once after training rather than
    inside the periodic checkpoint branch -- confirm.
    """
    # The context manager closes the session even when training raises.
    with tf.Session() as sess:
        config = get_config(is_train=True)
        mkdir(config.tmp_dir)
        mkdir(config.ckpt_dir)
        mkdir(config.result_dir)

        model = DIRNet(sess, config, "DIRNet_tr", is_train=True)
        train_data = AryllaDataHandler("Data_1", is_train=True)
        eval_data = AryllaDataHandler("Data_2", is_train=False)

        log_path = config.result_dir + '/data.txt'

        # Stay None when config.iteration is 0 so the final deploy can be
        # skipped instead of failing on an unbound name.
        batch_x_eval = batch_y_eval = None

        for i in range(config.iteration):
            batch_x, batch_y = train_data.sample_pair(config.batch_size)
            train_loss = model.fit(batch_x, batch_y)

            # NOTE(review): the evaluation loss comes from fit(); if fit()
            # also applies an optimizer step, the model is being trained on
            # the evaluation data. Confirm, and switch to a forward-only
            # call if the model offers one.
            batch_x_eval, batch_y_eval = eval_data.sample_pair(
                config.batch_size)
            eval_loss = model.fit(batch_x_eval, batch_y_eval)

            print("Iteration {} ==> training ncc : {}, evaluate ncc : {} ".format(
                i+1, round(train_loss, 8), round(eval_loss, 8)))

            # The log keeps the 0-based index (the console shows i+1);
            # left unchanged so existing readers of data.txt still work.
            with open(log_path, 'a') as f:
                f.write(str(i)+" "+str(train_loss)+" "+str(eval_loss)+"\n")

            if (i+1) % 1000 == 0:
                model.deploy(config.tmp_dir, batch_x, batch_y)
                model.save(config.ckpt_dir)
                print("Model saved ... ...")

        if batch_x_eval is None:
            # No iteration ran, so there is no evaluation batch to deploy.
            return

        # NOTE(review): both directories receive the same evaluation batch;
        # label_id only selects the output directory.
        for label_id in range(2):
            result_i_dir = config.result_dir+"/{}".format(label_id)
            # Nothing else in this function creates the per-label directory.
            mkdir(result_i_dir)
            model.deploy(result_i_dir, batch_x_eval, batch_y_eval)
def main():
    """Train DIRNet on MNIST pairs, snapshotting every 1000 iterations.

    Each iteration samples one batch pair from the MNIST data handler,
    fits the model on it and prints the returned loss. Every 1000
    iterations the current batch is deployed into ``config.tmp_dir`` and
    the model is saved into ``config.ckpt_dir``.
    """
    # The context manager closes the session even when training raises.
    with tf.Session() as sess:
        config = get_config(is_train=True)
        mkdir(config.tmp_dir)
        mkdir(config.ckpt_dir)

        reg = DIRNet(sess, config, "DIRNet", is_train=True)
        dh = MNISTDataHandler("MNIST_data", is_train=True)

        for i in range(config.iteration):
            batch_x, batch_y = dh.sample_pair(config.batch_size)
            loss = reg.fit(batch_x, batch_y)
            print("iter {:>6d} : {}".format(i + 1, loss))

            if (i + 1) % 1000 == 0:
                reg.deploy(config.tmp_dir, batch_x, batch_y)
                reg.save(config.ckpt_dir)
def main(num_labels=1):
    """Restore a trained DIRNet and deploy it on MNIST evaluation pairs.

    For every index ``i`` in ``range(num_labels)`` a sub-directory
    ``<config.result_dir>/<i>`` is created, one batch pair is sampled with
    ``sample_pair(config.batch_size, i)`` and deployed into it.

    Args:
        num_labels: Number of indices to process. Defaults to 1, the value
            that was previously hard-coded (a trailing ``#10`` comment
            suggests 10 was used as well). The index is handed to
            ``sample_pair`` as its second argument -- presumably the digit
            label; confirm against the data handler.
    """
    # The context manager closes the session even when deployment raises.
    with tf.Session() as sess:
        config = get_config(is_train=False)
        mkdir(config.result_dir)

        reg = DIRNet(sess, config, "DIRNet", is_train=False)
        reg.restore(config.ckpt_dir)
        dh = MNISTDataHandler("MNIST_data", is_train=False)

        for i in range(num_labels):
            result_i_dir = config.result_dir + "/{}".format(i)
            mkdir(result_i_dir)

            batch_x, batch_y = dh.sample_pair(config.batch_size, i)
            reg.deploy(result_i_dir, batch_x, batch_y)
def _stack_digits(data, label, per_digit):
    """Group one batch of images by digit label and stack the groups.

    For each digit 0..9 the rows of ``data`` whose ``label`` equals that
    digit are selected (in batch order) and truncated to ``per_digit``
    rows; the ten groups are stacked along a new leading dimension.

    ``torch.stack`` needs all ten groups to have the same size, so the
    batch must hold at least ``per_digit`` images of every digit.
    """
    # Boolean masking returns the matching rows in order and, unlike
    # ``nonzero().squeeze()``, keeps working when a digit matches once.
    groups = [data[label.eq(d)][:per_digit] for d in range(10)]
    return torch.stack(groups, dim=0)


def main():
    """Train DIRNet on per-digit MNIST stacks built from the first batch.

    The first batch of the train loader (up to 3000 images per digit) and
    of the test loader (up to 500 per digit) are grouped by label and
    wrapped in ``MNISTDataHandler``. The model is then optimised with Adam
    and a ``StepLR`` schedule for ``config.iteration`` steps. Every 100
    steps the summed loss is printed and reset, and a test pair batch is
    deployed into ``config.tmp_dir``.

    NOTE(review): the original formatting of this function was lost; the
    test deploy is assumed to belong to the every-100-steps branch --
    confirm.
    """
    config = get_config(is_train=True)
    mkdir(config.tmp_dir)
    mkdir(config.ckpt_dir)

    model = DIRNet(config)

    # NOTE(review): `tf` here exposes Compose/Resize/ToTensor, so it is
    # presumably the torchvision transforms module, not TensorFlow.
    transform = tf.Compose([tf.Resize([16, 16]), tf.ToTensor()])
    train_loader = DataLoader(
        MNIST('mnist', train=True, download=True, transform=transform),
        batch_size=train_batch)
    test_loader = DataLoader(
        MNIST('mnist', train=False, download=True, transform=transform),
        batch_size=test_batch)

    # Only the first batch of each loader is used, so fetch just that one
    # instead of iterating (and transforming) the whole dataset.
    data, label = next(iter(train_loader))
    digit = _stack_digits(data, label, per_digit=3000)

    data, label = next(iter(test_loader))
    digit_t = _stack_digits(data, label, per_digit=500)

    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = StepLR(optim, step_size=200, gamma=0.5)

    train_pr = MNISTDataHandler(digit)
    test_pr = MNISTDataHandler(digit_t)

    total_loss = 0.0
    for i in range(config.iteration):
        batch_x, batch_y = train_pr.sample_pair(config.batch_size)

        optim.zero_grad()
        _, loss = model(batch_x, batch_y)
        loss.backward()
        optim.step()
        scheduler.step()

        # Accumulate a plain float: adding the tensor itself would keep
        # every step's autograd graph alive until the next reset.
        total_loss += loss.item()

        if (i + 1) % 100 == 0:
            print("iter {:>6d} : {}".format(i + 1, total_loss))
            total_loss = 0.0

            batch_x, batch_y = test_pr.sample_pair(config.batch_size)
            model.deploy(config.tmp_dir, batch_x, batch_y)