def train(self, config):
    """Train the network with Adam, displaying progress each epoch.

    Reads from ``config``: ``learning_rate``, ``beta1``, ``checkpoint_dir``,
    ``epoch`` (number of epochs) and ``batch_size``. Uses ``self.loss``,
    ``self.vars``, ``self.sess``, ``self.image_size``, ``self.nei_num``, the
    placeholders ``self.input`` / ``self.gt`` / ``self.is_training``,
    ``self.det_test``, ``self.neighbor``, ``self.counter`` and the
    ``load`` / ``save`` / ``display`` methods.

    Samples are files loaded with ``np.load`` from directories under the
    current working directory:

    * test:  ``coco/<image_size>/test/{img_pre, knn_img/<nei_num>, gt_img}``
    * train: ``coco/<image_size>/train/{<name>_pre, knn_<name>/<nei_num>,
      gt_<name>}`` (or the ``PASCAL_VOC_2012`` layout when ``dataset == 0``).

    Images and labels are paired purely by sorted file-name order.

    Raises:
        ValueError: if an image directory and its label directory contain a
            different number of files. After the shared-seed shuffle, the
            pairs would otherwise be silently misaligned.
    """

    def _sorted_listdir(path):
        # Sorted so that image/label files pair up by name order.
        return sorted(os.listdir(path))

    def _load(loc, fname):
        return np.load(loc + '/' + fname)

    def _check_paired(images, labels, where):
        # Shuffling two lists with the same seed applies the same permutation
        # only when they have equal length; otherwise pairs get scrambled.
        if len(images) != len(labels):
            raise ValueError(
                "%s: %d image files but %d label files"
                % (where, len(images), len(labels)))

    cwd = os.getcwd()
    optim = tf.train.AdamOptimizer(
        config.learning_rate, beta1=config.beta1).minimize(
            self.loss, var_list=self.vars)
    # Initialize every variable (including Adam's slots) first; self.load
    # presumably restores checkpointed values on top of this.
    self.sess.run(tf.global_variables_initializer())
    self.sum = tf.summary.merge_all()
    self.writer = tf.summary.FileWriter(config.checkpoint_dir, self.sess.graph)

    load_dir = config.checkpoint_dir
    if self.load(load_dir, config=config):
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    # ---- Test split (listed once) ----
    testimg_loc = cwd + '/coco/%d/test/img_pre' % (self.image_size)
    testimg = _sorted_listdir(testimg_loc)
    testlabel_loc = cwd + '/coco/%d/test/knn_img/%d' % (
        self.image_size, self.nei_num)
    testlabel = _sorted_listdir(testlabel_loc)
    testgt_loc = cwd + '/coco/%d/test/gt_img' % (self.image_size)
    testgt = _sorted_listdir(testgt_loc)[16:21]
    _check_paired(testimg, testlabel, testimg_loc)
    # Fixed, unshuffled test samples reused by the periodic visualization.
    ds_testimg = np.copy(testimg[16:21])
    ds_testlabel = np.copy(testlabel[16:21])

    for epoch in range(config.epoch):
        start_epoch = time.time()

        # ---- Choose the training source for this epoch ----
        dataset = 1  # randint(0, 1)
        if dataset == 0:
            # PASCAL VOC 2012 layout (disabled while ``dataset`` is fixed to 1).
            rot = 0
            name = "train"
            trainimg_loc = cwd + '/PASCAL_VOC_2012/%d/r%d/%s_pre' % (
                self.image_size, rot, name)
            trainlabel_loc = cwd + '/PASCAL_VOC_2012/%d/r%d/knn/%s/%d' % (
                self.image_size, rot, name, self.nei_num)
            traingt_loc = cwd + '/PASCAL_VOC_2012/%d/r%d/gt_%s' % (
                self.image_size, rot, name)
        else:
            flip = 0  # randint(0, 1)
            name = 'img' if flip == 0 else 'flip'
            trainimg_loc = cwd + '/coco/%d/train/%s_pre' % (
                self.image_size, name)
            trainlabel_loc = cwd + '/coco/%d/train/knn_%s/%d' % (
                self.image_size, name, self.nei_num)
            traingt_loc = cwd + '/coco/%d/train/gt_%s' % (
                self.image_size, name)

        # NOTE(review): tf.device only affects graph ops created inside the
        # context; nothing in these blocks creates ops, so they look like
        # no-ops. Kept as-is — confirm before removing.
        with tf.device('/cpu:1'):
            trainimg = _sorted_listdir(trainimg_loc)
            trainlabel = _sorted_listdir(trainlabel_loc)
            traingt = _sorted_listdir(traingt_loc)[:5]
        _check_paired(trainimg, trainlabel, trainimg_loc)
        # Fixed, pre-shuffle training samples for the periodic visualization.
        ds_trainimg = np.copy(trainimg[:5])
        ds_trainlabel = np.copy(trainlabel[:5])

        # Same seed for both shuffles -> same permutation, so image i stays
        # paired with label i (lengths verified above).
        np.random.seed(epoch)
        np.random.shuffle(trainimg)
        np.random.seed(epoch)
        np.random.shuffle(trainlabel)

        # ---- One pass over full batches (a trailing partial batch is dropped) ----
        batch_idxs = len(trainimg) // config.batch_size
        with tf.device('/cpu:1'):
            for idx in range(batch_idxs):
                start = idx * config.batch_size
                stop = start + config.batch_size
                batch_files = [_load(trainimg_loc, f)
                               for f in trainimg[start:stop]]
                # self.gt is fed from the k-NN label directory, not gt_*.
                batch_gt = [_load(trainlabel_loc, f).astype(float)
                            for f in trainlabel[start:stop]]
                self.sess.run(optim, feed_dict={self.input: batch_files,
                                                self.gt: batch_gt,
                                                self.is_training: True})

        # ---- Per-epoch display on a shuffled subset ----
        # The test lists are shuffled in place, cumulatively across epochs,
        # with the same per-epoch seed so they stay paired.
        np.random.seed(epoch)
        np.random.shuffle(testimg)
        np.random.seed(epoch)
        np.random.shuffle(testlabel)
        # Up to 64 samples; clamp so small splits do not raise IndexError.
        testsize = min(64, len(trainimg), len(testimg))
        with tf.device('/cpu:1'):
            batch_files = [_load(trainimg_loc, f)
                           for f in trainimg[:testsize]]
            batch_label = [_load(trainlabel_loc, f).astype(float)
                           for f in trainlabel[:testsize]]
            test_batch_files = [_load(testimg_loc, f)
                                for f in testimg[:testsize]]
            test_batch_label = [_load(testlabel_loc, f).astype(float)
                                for f in testlabel[:testsize]]
        self.display(epoch, config.epoch, batch_files, batch_label,
                     test_batch_files, test_batch_label, 1, self.counter)
        self.counter = self.counter + 1

        epoch_time = time.time() - start_epoch
        print("Epoch run time : %.9f" % (epoch_time))

        # ---- Visualization on fixed samples at epochs 1, 11, 21, ... ----
        if np.mod(epoch, 10) == 1:
            train_piece = [_load(trainimg_loc, f) for f in ds_trainimg]
            train_piece_label = [_load(trainlabel_loc, f).astype(float)
                                 for f in ds_trainlabel]
            train_piece_gt = [_load(traingt_loc, f).astype(float)
                              for f in traingt]
            test_piece = [_load(testimg_loc, f) for f in ds_testimg]
            test_piece_label = [_load(testlabel_loc, f).astype(float)
                                for f in ds_testlabel]
            test_piece_gt = [_load(testgt_loc, f).astype(float)
                             for f in testgt]

            det_vi_train = self.sess.run(
                self.det_test,
                feed_dict={self.input: train_piece, self.is_training: False})
            save_train = config.checkpoint_dir + "/train"
            if not os.path.exists(save_train):
                os.makedirs(save_train)
            vi_train = Visualize(train_piece_gt, train_piece_label,
                                 det_vi_train, self.neighbor, save_train,
                                 epoch)
            vi_train.plot()

            det_vi_test = self.sess.run(
                self.det_test,
                feed_dict={self.input: test_piece, self.is_training: False})
            save_test = config.checkpoint_dir + "/test"
            if not os.path.exists(save_test):
                os.makedirs(save_test)
            vi_test = Visualize(test_piece_gt, test_piece_label, det_vi_test,
                                self.neighbor, save_test, epoch)
            vi_test.plot()

            # NOTE(review): source indentation was lost; this checkpoint save
            # is assumed to belong to the every-10-epochs block — confirm.
            self.save(config.checkpoint_dir, epoch, config)
# viz_test.py
# Andrew Kramer
# Command-line test driver for visualize.py: loads space-delimited label and
# prediction matrices and hands them to Visualize for plotting.

import argparse

import numpy as np

from visualize import Visualize


def _build_parser():
    """Return the argument parser for the visualizer test driver."""
    cli = argparse.ArgumentParser(description='results visualizer options')
    cli.add_argument('--index', type=int, required=True,
                     help='index of image pair to use')
    cli.add_argument('--labels', type=str, default='labels.txt',
                     help='file name for labels')
    cli.add_argument('--pred', type=str, default='pred.txt',
                     help='file name for predictions')
    return cli


def main():
    """Parse options, load predictions then labels, and plot one index."""
    opts = _build_parser().parse_args()
    predictions = np.loadtxt(opts.pred, delimiter=' ')
    ground_truth = np.loadtxt(opts.labels, delimiter=' ')
    Visualize(ground_truth, predictions).plot(opts.index)


if __name__ == "__main__":
    main()