def get_tfboard_writer(save_dir, model_name, dataset_name, save_time=None):
    """Create a TensorBoard SummaryWriter in a run-specific directory.

    The log directory is ``<save_dir>TensorBoard/<model>_<dataset>_<time>``
    (note: ``save_dir`` is concatenated as-is, so it is expected to end
    with a path separator — matches how callers in this file build paths).

    Args:
        save_dir: base output directory (string, trailing '/' expected).
        model_name: model identifier used in the directory name.
        dataset_name: dataset identifier used in the directory name.
        save_time: timestamp string; defaults to the module-level
            ``time_stamp`` when not given.

    Returns:
        (writer, tfboard_path) — the SummaryWriter and the directory it
        writes to (the path is reused by callers to launch TensorBoard).
    """
    # Fix: compare to None with `is`, not `==` (PEP 8; `==` can be
    # hijacked by operator overloading).
    if save_time is None:
        save_time = time_stamp
    tfboard_path = save_dir + 'TensorBoard/' + '_'.join([model_name, dataset_name, save_time])
    make_savedir(tfboard_path)  # ensure the directory exists before writing
    writer = SummaryWriter(tfboard_path)
    return writer, tfboard_path
def evaluation(sess, args, config):
    """Evaluate the da_cil model on the TARGET test split.

    Builds the test input pipeline and metric graph, restores the latest
    checkpoint, then runs ``args.num_eval`` batches accumulating per-class
    steering accuracy for each of the three command branches
    (left / straight / right). Also dumps a few domain-translated images
    (s2t / t2s) with histogram matching applied.

    NOTE(review): this file defines ``evaluation`` more than once; later
    definitions shadow this one at import time — confirm which is intended.

    Args:
        sess: a tf.Session with the default graph.
        args: parsed CLI args (batch_size, num_label, num_eval, print_info).
        config: ConfigParser-style experiment configuration.
    """
    #from scipy.misc import imread, imsave
    #imsrc = imread('front_215.jpeg')
    #imst = imread('left_92.jpeg')
    #imres = hist_matching(imsrc, imst)
    #imsave('matching.jpeg',imres)
    # Resolve experiment paths from config.
    model_type = config.get('config', 'experiment')
    base_dir = os.path.expanduser(config.get('config', 'basedir'))
    log_dir = os.path.join(base_dir, config.get('config', 'logdir'))
    tfrecord_dir = os.path.join(base_dir, config.get(model_type, 'tfrecord'))
    ckpt_dir = os.path.join(log_dir, utils.make_savedir(config))
    print(ckpt_dir)

    # The experiment name doubles as the module that defines `model`.
    model_file = importlib.import_module(model_type)
    model = getattr(model_file, 'model')
    da_model = model(args, config)

    # dataset_utils exposes one batch-loader function per experiment type.
    get_batches = getattr(dataset_utils, model_type)

    # Get test batches
    with tf.name_scope('batch'):
        source_image_batch, source_label_batch, source_measure_batch, source_command_batch = get_batches('source', 'test', tfrecord_dir, batch_size=args.batch_size, config=config)
        target_image_batch, target_label_batch, target_measure_batch, target_command_batch = get_batches('target','test',tfrecord_dir, batch_size=args.batch_size, config=config)

    # Call model
    da_model(source_image_batch, target_image_batch, source_measure_batch)

    with tf.name_scope('objective'):
        da_model.create_objective(source_label_batch, source_command_batch)

    # Per-branch class probabilities; da_model.end holds the three
    # command-branch logits (left, straight, right).
    command_l = tf.nn.softmax(da_model.end[0])
    command_s = tf.nn.softmax(da_model.end[1])
    command_r = tf.nn.softmax(da_model.end[2])

    # Predicted class as one-hot per branch.
    command_l_one_hot = slim.one_hot_encoding(tf.argmax(command_l, 1), args.num_label)
    command_s_one_hot = slim.one_hot_encoding(tf.argmax(command_s, 1), args.num_label)
    command_r_one_hot = slim.one_hot_encoding(tf.argmax(command_r, 1), args.num_label)

    # Ground-truth steering label as one-hot.
    labels = slim.one_hot_encoding(tf.argmax(target_label_batch, 1), args.num_label)

    # Per-class label counts, masked by which command branch each sample
    # belongs to (target_command_batch columns select the branch).
    command_l_labels = tf.reduce_sum(tf.expand_dims(target_command_batch[:,0], 1) * labels, 0)
    command_s_labels = tf.reduce_sum(tf.expand_dims(target_command_batch[:,1], 1) * labels, 0)
    command_r_labels = tf.reduce_sum(tf.expand_dims(target_command_batch[:,2], 1) * labels, 0)

    # Per-class correct-prediction counts (prediction AND label one-hot
    # overlap), again masked per command branch.
    command_l_correct = tf.reduce_sum(tf.expand_dims(target_command_batch[:,0], 1) * tf.multiply(labels, command_l_one_hot), 0)
    command_s_correct = tf.reduce_sum(tf.expand_dims(target_command_batch[:,1], 1) * tf.multiply(labels, command_s_one_hot), 0)
    command_r_correct = tf.reduce_sum(tf.expand_dims(target_command_batch[:,2], 1) * tf.multiply(labels, command_r_one_hot), 0)

    sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))

    # Load checkpoint
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        tf.logging.info('Checkpoint loaded from %s' % ckpt_dir)
    else:
        # Evaluating uninitialized weights is meaningless — bail out.
        tf.logging.warn('Checkpoint not loaded')
        return

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        # NOTE(review): accumulators are hard-coded to 5 slots but indexed
        # by range(args.num_label) below — assumes num_label <= 5; confirm.
        left_label_num = np.zeros(5)
        str_label_num = np.zeros(5)
        right_label_num= np.zeros(5)
        left_acc = np.zeros(5)
        str_acc = np.zeros(5)
        right_acc = np.zeros(5)
        for idx in range(args.num_eval):
            # Fetch concatenated source/target images and both generator
            # outputs, then histogram-match the translations to the real
            # domain statistics before saving sample images.
            sim, tim, s2tim, t2sim = sess.run([da_model.source_concat,da_model.target_concat, da_model.g_s2t, da_model.g_t2s])
            s2tim_mat = hist_matching(s2tim, tim)
            t2sim_mat = hist_matching(t2sim, sim)
            imsave('s2tim.png',s2tim[0,:,:,:])
            # Images are in [-1, 1]; rescale to [0, 255] uint8 for saving.
            imsave('s2tim_mat.png',np.uint8((s2tim_mat[0,:,:,:]+1)/2*255))
            imsave('t2sim.png',t2sim[0,:,:,:])
            imsave('t2sim_mat.png',t2sim_mat[0,:,:,:])
            # One sess.run so every metric is computed on the SAME batch.
            command, steer_label, np_im, steer_pred_l, steer_pred_s, steer_pred_r, cmd_l_label_num, cmd_s_label_num, cmd_r_label_num, cmd_l_correct, cmd_s_correct, cmd_r_correct \
                    = sess.run([target_command_batch, target_label_batch, target_image_batch, command_l, command_s, command_r, command_l_labels, command_s_labels, command_r_labels, command_l_correct, command_s_correct, command_r_correct])
            # NOTE(review): grouping of the diagnostic prints under
            # print_info was reconstructed from flattened source — confirm.
            if args.print_info:
                print('COMMAND LABEL')
                print(command)
                print('STEER')
                print(steer_label)
                print('-'*20)
                accuracy(steer_pred_l, steer_pred_s, steer_pred_r, args.batch_size, command, args.num_label)
                print(cmd_l_label_num)
                print(cmd_s_label_num)
                print(cmd_r_label_num)
                print('-'*10)
                print(cmd_l_correct)
                print(cmd_s_correct)
                print(cmd_r_correct)
            # Accumulate per-class counts across evaluation batches.
            left_label_num += cmd_l_label_num
            str_label_num += cmd_s_label_num
            right_label_num += cmd_r_label_num
            left_acc += cmd_l_correct
            str_acc += cmd_s_correct
            right_acc += cmd_r_correct
        # Convert correct-counts into per-class accuracy.
        # NOTE(review): divides by zero (NaN/inf) for classes never seen
        # during evaluation — confirm acceptable.
        for i in range(args.num_label):
            left_acc[i] = left_acc[i] / left_label_num[i]
            str_acc[i] = str_acc[i] / str_label_num[i]
            right_acc[i] = right_acc[i] / right_label_num[i]
        print('LEFT COMMAND ACCURACY')
        print(left_acc)
        print('STRAIGHT COMMAND ACCURACY')
        print(str_acc)
        print('RIGHT COMMAND ACCURACY')
        print(right_acc)
    finally:
        # Always stop queue-runner threads, even on error/interrupt.
        coord.request_stop()
        coord.join(threads)
def train(sess, args, config):
    """Train a domain-adaptation model ('da_cil' or 'pixel_da').

    Builds the input pipeline and loss graph for the experiment named in
    the config, then alternates discriminator and generator updates for
    ``args.max_iter`` iterations, periodically saving checkpoints and
    TensorBoard summaries. In ``source_only`` mode only the task
    (discriminator) loss is trained and no generator ops are created.

    Args:
        sess: a tf.Session with the default graph.
        args: parsed CLI args (batch_size, max_iter, optimizer, lr_decay,
            clip_norm, save/summary intervals, delete, load_ckpt, ...).
        config: ConfigParser-style experiment configuration.
    """
    # Resolve experiment type, paths and loss-weight hyperparameters.
    model_type = config.get('config', 'experiment')
    base_dir = os.path.expanduser(config.get('config', 'basedir'))
    tfrecord_dir = os.path.join(base_dir, config.get(model_type, 'tfrecord'))
    log_dir = os.path.join(base_dir, config.get('config', 'logdir'))
    adversarial_mode = config.get('config', 'mode')
    whether_noise = config.getboolean('generator', 'noise')
    t2s_task = config.getboolean('config', 't2s_task')
    noise_dim = config.getint('generator', 'noise_dim')
    source_only = config.getboolean('config', 'source_only')

    # Loss weights for the combined GAN / cyclic / style / task objective.
    s2t_adversarial_weight = config.getfloat(model_type, 's2t_adversarial_weight')
    t2s_adversarial_weight = config.getfloat(model_type, 't2s_adversarial_weight')
    s2t_cyclic_weight = config.getfloat(model_type, 's2t_cyclic_weight')
    t2s_cyclic_weight = config.getfloat(model_type, 't2s_cyclic_weight')
    s2t_task_weight = config.getfloat(model_type, 'task_weight')
    t2s_task_weight = config.getfloat(model_type, 't2s_task_weight')
    s2t_style_weight = config.getfloat(model_type, 's2t_style_weight')
    t2s_style_weight = config.getfloat(model_type, 't2s_style_weight')
    # How many D / G updates to run per outer iteration.
    discriminator_step = config.getint(model_type, 'discriminator_step')
    generator_step = config.getint(model_type, 'generator_step')

    save_dir = os.path.join(log_dir, utils.make_savedir(config))
    # save_dir = os.path.join(log_dir, config.get('config', 'savedir'))
    # --delete wipes any previous run in the same save directory.
    if args.delete and os.path.exists(save_dir):
        shutil.rmtree(save_dir)
    os.makedirs(save_dir, exist_ok=True)

    # The experiment name doubles as the module that defines `model`.
    model_path = importlib.import_module(model_type)
    model = getattr(model_path, 'model')
    da_model = model(args, config)

    writer = tf.summary.FileWriter(save_dir, sess.graph)
    global_step = tf.train.get_or_create_global_step()
    # dataset_utils exposes one batch-loader function per experiment type.
    get_batches = getattr(dataset_utils, model_type)
    tf.logging.info('Training %s with %s' % (model_type, adversarial_mode))

    if model_type == 'da_cil':
        with tf.name_scope(model_type + '_batches'):
            source_image_batch, source_label_batch, source_measure_batch, source_command_batch = get_batches('source', 'train', tfrecord_dir, batch_size=args.batch_size, config=config, args=args)
            if source_only:
                # No target domain: the model only sees source images.
                da_model(source_image_batch, None, source_measure_batch)
            else:
                target_image_batch, _, _, _ = get_batches('target', 'train', tfrecord_dir, batch_size=args.batch_size, config=config, args=args)
                da_model(source_image_batch, target_image_batch, source_measure_batch)

        with tf.name_scope(model_type + '_objectives'):
            da_model.create_objective(source_label_batch, source_command_batch, adversarial_mode)
            if source_only:
                # Task loss alone plays the "discriminator" role so the
                # single training loop below works unchanged.
                discriminator_loss = da_model.task_loss
                da_model.summary['discriminator_loss'] = discriminator_loss
            else:
                # adversarial_loss[0] is the generator-side term,
                # adversarial_loss[1] the discriminator-side term.
                generator_loss = s2t_cyclic_weight * da_model.s2t_cyclic_loss + t2s_cyclic_weight * da_model.t2s_cyclic_loss + da_model.s2t_adversarial_loss[0] + da_model.t2s_adversarial_loss[0]
                generator_loss += s2t_style_weight * da_model.s2t_style_loss + t2s_style_weight * da_model.t2s_style_loss
                da_model.summary['generator_loss'] = generator_loss

                discriminator_loss = s2t_adversarial_weight * da_model.s2t_adversarial_loss[1] + t2s_adversarial_weight * da_model.t2s_adversarial_loss[1] + s2t_task_weight * da_model.task_loss
                if t2s_task:
                    discriminator_loss += t2s_task_weight * da_model.t2s_task_loss
                da_model.summary['discriminator_loss'] = discriminator_loss
    elif model_type == 'pixel_da':
        with tf.name_scope(model_type + '_batches'):
            source_image_batch, source_label_batch = get_batches('source', 'train', tfrecord_dir, batch_size=args.batch_size, config=config)
            # Channel 3 is a mask; channels 0-2 are the RGB image.
            mask_image_batch = source_image_batch[:,:,:,3]
            source_image_batch = source_image_batch[:,:,:,:3]
            if config.getboolean(model_type, 'input_mask'):
                tf.logging.info('Using masked input')
                # Binarize the mask and zero out masked-off pixels.
                mask_images = tf.to_float(tf.greater(mask_image_batch, 0.9))
                source_image_batch = tf.multiply(source_image_batch, tf.tile(tf.expand_dims(mask_images, 3), [1,1,1,3]))
            # Label is already an 1-hot labels, but we expect categorical
            source_label_max_batch = tf.argmax(source_label_batch, 1)
            # Decode the combined label index into lateral (0-2) and
            # head (0-2) sub-labels.
            source_lateral_label_batch = (source_label_max_batch % 9) / 3
            source_head_label_batch = source_label_max_batch % 3
            target_image_batch, _ = get_batches('target', 'train', tfrecord_dir, batch_size=args.batch_size, config=config)
            da_model(source_image_batch, target_image_batch)

        with tf.name_scope(model_type + '_objectives'):
            da_model.create_objective(source_head_label_batch, source_lateral_label_batch, adversarial_mode)
            generator_loss = s2t_cyclic_weight * da_model.s2t_cyclic_loss + t2s_cyclic_weight * da_model.t2s_cyclic_loss + da_model.s2t_adversarial_loss[0] + da_model.t2s_adversarial_loss[0]
            generator_loss += s2t_style_weight * da_model.s2t_style_loss + t2s_style_weight * da_model.t2s_style_loss
            da_model.summary['generator_loss'] = generator_loss

            discriminator_loss = s2t_adversarial_weight * da_model.s2t_adversarial_loss[1] + t2s_adversarial_weight * da_model.t2s_adversarial_loss[1] + s2t_task_weight * da_model.transferred_task_loss
            if t2s_task:
                discriminator_loss += t2s_task_weight * da_model.t2s_task_loss
            da_model.summary['discriminator_loss'] = discriminator_loss
    else:
        raise Exception('Not supported model')

    with tf.name_scope('optimizer'):
        tf.logging.info('Getting optimizer')
        if args.lr_decay:
            decay_steps = config.getint('optimizer', 'decay_steps')
            decay_rate = config.getfloat('optimizer', 'decay_rate')
            learning_rate = tf.train.exponential_decay(args.learning_rate, global_step, decay_steps, decay_rate, staircase=True)
        else:
            learning_rate = args.learning_rate
        # Separate optimizers for G and D; gradients are clipped by
        # the project helper _gradient_clip.
        if not source_only:
            g_optimizer = _get_optimizer(config, args.optimizer)(learning_rate)
            g_optim = _gradient_clip(name='generator', optimizer=g_optimizer, loss=generator_loss, global_steps=global_step, clip_norm=args.clip_norm)
        d_optimizer = _get_optimizer(config, args.optimizer)(learning_rate)
        d_optim = _gradient_clip(name='discriminator', optimizer=d_optimizer, loss=discriminator_loss, global_steps=global_step, clip_norm=args.clip_norm)

    if not source_only:
        generator_summary, discriminator_summary = utils.summarize(da_model.summary, t2s_task)
        # Record the hyperparameter configuration alongside the run.
        utils.config_summary(save_dir, s2t_adversarial_weight, t2s_adversarial_weight, s2t_cyclic_weight, t2s_cyclic_weight, s2t_task_weight, t2s_task_weight, discriminator_step, generator_step, adversarial_mode, whether_noise, noise_dim, s2t_style_weight, t2s_style_weight)
    else:
        discriminator_summary = utils.summarize(da_model.summary, t2s_task, source_only)

    saver = tf.train.Saver(max_to_keep=5)
    sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))

    # Optionally resume from the latest checkpoint in save_dir.
    if args.load_ckpt:
        ckpt = tf.train.get_checkpoint_state(save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        for iter_count in range(args.max_iter):
            # Update discriminator
            for disc_iter in range(discriminator_step):
                d_loss, _, steps = sess.run([discriminator_loss, d_optim, global_step])
                # FISHER mode needs an extra update op stored at the end
                # of the adversarial-loss tuples.
                if not source_only and adversarial_mode == 'FISHER':
                    _, _ = sess.run([da_model.s2t_adversarial_loss[-1], da_model.t2s_adversarial_loss[-1]])
                #writer.add_summary(disc_sum, steps)
                tf.logging.info('Step %d: Discriminator loss=%.5f', steps, d_loss)
            # Update generator (skipped entirely in source-only mode).
            if not source_only:
                for gen_iter in range(generator_step):
                    g_loss, _, steps = sess.run([generator_loss, g_optim, global_step])
                    #writer.add_summary(gen_sum, steps)
                    tf.logging.info('Step %d: Generator loss=%.5f', steps, g_loss)
            if (iter_count+1) % args.save_interval == 0:
                saver.save(sess, os.path.join(save_dir, model_type), global_step=(iter_count+1))
                tf.logging.info('Checkpoint save')
            if (iter_count+1) % args.summary_interval == 0:
                if not source_only:
                    disc_sum, gen_sum = sess.run([discriminator_summary, generator_summary])
                    writer.add_summary(gen_sum, steps)
                else:
                    disc_sum = sess.run(discriminator_summary)
                writer.add_summary(disc_sum, steps)
                tf.logging.info('Summary at %d step' % (iter_count+1))
    except tf.errors.OutOfRangeError:
        # Input queues exhausted (epoch limit reached).
        print('Epoch limited')
    except KeyboardInterrupt:
        print('End training')
    finally:
        # Always stop queue-runner threads, even on error/interrupt.
        coord.request_stop()
        coord.join(threads)
def evaluation(sess, args, config):
    """Evaluate the pixel_da model on the test split.

    Restores the latest checkpoint, then iterates ~20000/batch_size test
    batches accumulating per-class accuracy, a 3x3 confusion table and
    command statistics for both the lateral and head prediction branches.
    On the first batch a few source / transferred (g_s2t) images are
    saved under ``<ckpt_dir>/im`` for visual inspection.

    NOTE(review): this file defines ``evaluation`` more than once; later
    definitions shadow this one at import time — confirm which is intended.

    Args:
        sess: a tf.Session with the default graph.
        args: parsed CLI args (batch_size, ...).
        config: ConfigParser-style experiment configuration.
    """
    # Resolve experiment paths from config.
    model_type = config.get('config', 'experiment')
    base_dir = os.path.expanduser(config.get('config', 'basedir'))
    log_dir = os.path.join(base_dir, config.get('config', 'logdir'))
    tfrecord_dir = os.path.join(base_dir, config.get(model_type, 'tfrecord'))
    ckpt_dir = os.path.join(log_dir, utils.make_savedir(config))
    print(ckpt_dir)

    # The experiment name doubles as the module that defines `model`.
    model_file = importlib.import_module(model_type)
    model = getattr(model_file, 'model')
    da_model = model(args, config)

    # Get test batches
    with tf.name_scope('batch'):
        source_image_batch, source_label_batch = dataset_utils.get_batches(
            'source', 'test', tfrecord_dir, batch_size=args.batch_size, config=config)
        # Channel 3 is a mask; channels 0-2 are the RGB image.
        mask_image_batch = source_image_batch[:, :, :, 3]
        source_image_batch = source_image_batch[:, :, :, :3]
        if config.getboolean('config', 'input_mask'):
            # 0/1 -> -1/1, -1 would be 0
            mask_images = tf.to_float(tf.greater(mask_image_batch, 0.99))
            source_image_batch = tf.multiply(
                source_image_batch, tf.tile(tf.expand_dims(mask_images, 3), [1, 1, 1, 3]))
        target_image_batch, target_label_batch = dataset_utils.get_batches(
            'target', 'test', tfrecord_dir, batch_size=args.batch_size, config=config)
        # Decode the combined label index into lateral (0-2) and
        # head (0-2) sub-labels.
        target_label_max_batch = tf.argmax(target_label_batch, 1)
        target_lateral_label_batch = (target_label_max_batch % 9) / 3
        target_head_label_batch = target_label_max_batch % 3

    # Call model
    da_model(source_image_batch, target_image_batch)

    #sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
    # Load checkpoint
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        tf.logging.info('Checkpoint loaded from %s' % ckpt_dir)
    else:
        # Evaluating uninitialized weights is meaningless — bail out.
        tf.logging.warn('Checkpoint not loaded')
        return

    # Ground-truth one-hot labels and per-class sample counts per batch.
    target_lateral_one_hot_label = one_hot(target_lateral_label_batch)
    target_head_one_hot_label = one_hot(target_head_label_batch)
    lateral_label_num = tf.reduce_sum(target_lateral_one_hot_label, 0)
    head_label_num = tf.reduce_sum(target_head_one_hot_label, 0)

    # Command statistics (via project helper get_command), masked by class.
    lateral_command = get_command(da_model.target_lateral_logits)
    lateral_cmd_label = tf.multiply(target_lateral_one_hot_label, lateral_command)
    head_command = get_command(da_model.target_head_logits)
    head_cmd_label = tf.multiply(target_head_one_hot_label, head_command)

    # Per-class correct-prediction counts (argmax prediction vs label).
    model_lateral_one_hot = one_hot(
        tf.argmax(da_model.target_lateral_logits, 1))
    model_head_one_hot = one_hot(tf.argmax(da_model.target_head_logits, 1))
    lateral_correct_one = tf.reduce_sum(
        tf.multiply(model_lateral_one_hot, target_lateral_one_hot_label), 0)
    head_correct_one = tf.reduce_sum(
        tf.multiply(model_head_one_hot, target_head_one_hot_label), 0)

    # Broadcast labels against softmax probabilities to build per-sample
    # (true class x predicted prob) confusion tables.
    target_lateral_one_hot_label = tf.expand_dims(target_lateral_one_hot_label, 2)
    target_head_one_hot_label = tf.expand_dims(target_head_one_hot_label, 2)
    lateral_confusion_table = tf.multiply(
        target_lateral_one_hot_label,
        tf.expand_dims(tf.nn.softmax(da_model.target_lateral_logits), 1))
    head_confusion_table = tf.multiply(
        target_head_one_hot_label,
        tf.expand_dims(tf.nn.softmax(da_model.target_head_logits), 1))

    #target_high_pass = op.neg_gaussian_filter(target_image_batch)
    #trans_high_pass = op.neg_gaussian_filter(da_model.g_s2t)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        #target_head_prob, target_lateral_prob = sess.run([tf.nn.softmax(da_model.target_head_logits), tf.nn.softmax(da_model.target_lateral_logits)])
        # numpy accumulators over all test batches (3 classes per branch).
        np_lateral_confusion = np.zeros((3, 3))
        np_head_confusion = np.zeros((3, 3))
        np_lateral_num = np.zeros(3)
        np_head_num = np.zeros(3)
        np_lateral_acc = np.zeros(3)
        np_head_acc = np.zeros(3)
        np_lateral_cmd = np.zeros(3)
        np_head_cmd = np.zeros(3)
        # Per-sample confusion rows are concatenated for the mean/std
        # report at the end; first all-zero row is a placeholder.
        np_l_conf_array = np.zeros((1, 3, 3))
        np_h_conf_array = np.zeros((1, 3, 3))
        # NOTE(review): assumes the test set holds 20000 samples — confirm.
        test_batch_num = int(math.floor(20000 / args.batch_size))
        for idx in range(test_batch_num):
            if idx == 0:
                #source, transferred_im, np_tar_high_pass, np_trans_high_pass = sess.run([source_image_batch, da_model.g_s2t, target_high_pass, trans_high_pass])
                # Save a few source / transferred images once per run.
                source, transferred_im = sess.run(
                    [source_image_batch, da_model.g_s2t])
                import matplotlib.pyplot as plt
                import scipy.misc
                im_dir = os.path.join(ckpt_dir, 'im')
                if not os.path.exists(im_dir):
                    os.mkdir(im_dir)
                for im_idx in range(4):
                    scipy.misc.imsave(
                        os.path.join(im_dir, 'source%d.png' % im_idx),
                        source[im_idx, :, :, :])
                    scipy.misc.imsave(
                        os.path.join(im_dir, 'trans%d.png' % im_idx),
                        transferred_im[im_idx, :, :, :])
                    #scipy.misc.imsave(os.path.join(im_dir,'trans_high%d.png' % im_idx), np.squeeze(np_trans_high_pass[im_idx,:,:,:]))
                    #scipy.misc.imsave(os.path.join(im_dir,'tar_high%d.png' % im_idx), np.squeeze(np_tar_high_pass[im_idx]))
            # One sess.run so every metric is computed on the SAME batch.
            t1, t2, x1, x2, y1, y2, z1, z2, a1, a2 = sess.run([
                lateral_confusion_table, head_confusion_table,
                tf.reduce_sum(lateral_confusion_table, 0),
                tf.reduce_sum(head_confusion_table, 0), lateral_label_num,
                head_label_num, lateral_correct_one, head_correct_one,
                tf.reduce_sum(lateral_cmd_label, 0),
                tf.reduce_sum(head_cmd_label, 0)
            ])
            np_lateral_confusion = np_lateral_confusion + x1
            np_head_confusion = np_head_confusion + x2
            np_lateral_num = np_lateral_num + y1
            np_head_num = np_head_num + y2
            np_lateral_acc = np_lateral_acc + z1
            np_head_acc = np_head_acc + z2
            np_lateral_cmd += a1
            np_head_cmd += a2
            np_l_conf_array = np.concatenate([np_l_conf_array, t1], 0)
            np_h_conf_array = np.concatenate([np_h_conf_array, t2], 0)
            printProgress(idx, test_batch_num, 'Progress', 'Complete', 1, 50)
        # Normalize counts into per-class rates.
        # NOTE(review): divides by zero for classes never seen — confirm.
        for i in range(3):
            np_lateral_confusion[i, :] /= np_lateral_num[i]
            np_head_confusion[i, :] /= np_head_num[i]
            np_lateral_acc[i] /= np_lateral_num[i]
            np_head_acc[i] /= np_head_num[i]
            np_lateral_cmd[i] /= np_lateral_num[i]
            np_head_cmd[i] /= np_head_num[i]
        print('\nlateral_confusion')
        print(np_lateral_confusion)
        print('head_confusion')
        print(np_head_confusion)
        print('lateral-accuracy')
        print(np_lateral_acc)
        print('head-accuracy')
        print(np_head_acc)
        print('lateral cmd')
        print(np_lateral_cmd)
        print('head cmd')
        print(np_head_cmd)
        # Report mean/std of (right-prob minus left-prob) per ground-truth
        # class, skipping rows where the sample did not belong to class i.
        for i in range(3):
            y = np_l_conf_array[:, i, 2] - np_l_conf_array[:, i, 0]
            z = np_l_conf_array[:, i, 0]
            print('when GT lateral is %d command mean : %.4f std : %.4f' %
                  (i, np.mean(y[z != 0]), np.std(y[z != 0])))
            y = np_h_conf_array[:, i, 2] - np_h_conf_array[:, i, 0]
            z = np_h_conf_array[:, i, 0]
            print('when GT head is %d command mean : %.4f std : %.4f' %
                  (i, np.mean(y[z != 0]), np.std(y[z != 0])))
    finally:
        # Always stop queue-runner threads, even on error/interrupt.
        coord.request_stop()
        coord.join(threads)
def train_model(config, save_dir):
    """Train a PyTorch image classifier and save the best checkpoint.

    Loads data and model via project helpers, runs the train/valid phases
    for ``config.epoch`` epochs with cross-entropy loss, tracks the best
    validation accuracy, optionally logs to TensorBoard (and launches it
    on the first epoch), and finally saves the model and runs the test
    phase if present.

    Args:
        config: parsed experiment configuration (epoch, tfboard,
            save_best, preTrain, a ``*model*`` attribute naming the model,
            optimizer settings, ...).
        save_dir: base output directory (string, trailing '/' expected —
            paths are built by concatenation as elsewhere in this file).
    """
    model_path = save_dir + 'Model/'
    make_savedir(model_path)

    # Load dataset; infer channels / classes / width from the first sample.
    dataloader, len_dataset, dataset_name = load_dataloader(config, save_dir)
    in_ch = dataloader['train'].dataset[0][0].shape[0]
    classes = dataloader['train'].dataset.classes
    width = dataloader['train'].dataset[0][0].shape[-1]
    phases = dataloader.keys()
    num_classes = len(classes)

    # Load model: the first config attribute containing 'model' with a
    # non-None value names the architecture.
    for key in vars(config).keys():
        if ('model' in key) and vars(config)[key] is not None:
            model_name = vars(config)[key]
            break
    model = get_model(model_name, in_ch, num_classes, config.preTrain)
    model.classes = classes
    model, device, parallel = config_device(config, model)

    criterion = nn.CrossEntropyLoss()
    optimizer = get_optim(config, model)

    if config.save_best:
        best_model_wts = copy.deepcopy(model.state_dict())

    if config.tfboard:
        writer, tfboard_path = get_tfboard_writer(save_dir, model_name, dataset_name)
        images, _ = next(iter(dataloader['train']))
        # DataParallel wraps the real model in .module.
        if parallel:
            writer.add_graph(model.module, images.to(device))
        else:
            writer.add_graph(model, images.to(device))

    num_epochs = config.epoch
    best_acc = 0
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        print('\n\nEpoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 60)
        for phase in phases:
            if phase == 'train':
                model.train()
            elif phase =='valid':
                model.eval()
            else:
                # Any other phase (e.g. 'test') is handled after training.
                break
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in dataloader[phase]:
                # Skip size-1 batches (e.g. BatchNorm cannot handle them
                # in train mode).
                if len(labels) == 1:
                    continue
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Gradients only needed for the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # BUG FIX: was `running_loss = ...`, which overwrote the
                # accumulator each batch so epoch_loss only reflected the
                # final batch; accumulate like running_corrects below.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            epoch_loss = running_loss / len_dataset[phase]
            epoch_acc = running_corrects.double() / len_dataset[phase]
            print('<{}>\t\tLoss : {:.4f}\t\tAcc : {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            if config.tfboard:
                writer.add_scalar(phase+'_loss', epoch_loss, global_step=epoch+1)
                writer.add_scalar(phase+'_acc', epoch_acc, global_step=epoch+1)
        if config.tfboard:
            # `inputs`/`labels` here are the last batch of the last phase.
            writer.add_figure('valid : predictions vs. actuals',
                              plot_classes_preds(model, inputs, classes, labels),
                              global_step=epoch+1)
            # Launch a TensorBoard server once and open it in the browser.
            if epoch == 0:
                tb = program.TensorBoard()
                tb.configure(argv=[None, '--logdir', tfboard_path])
                url = tb.launch()
                webbrowser.open(url)
        time_elapsed = time.time()-epoch_start_time
        print('elapsed time : {:.0f}m {:.0f}s'.format(
            time_elapsed//60, time_elapsed%60))
        # Track the best (validation) accuracy; either remember the
        # weights (save_best) or write a checkpoint immediately.
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            if config.save_best:
                best_model_wts = copy.deepcopy(model.state_dict())
            else:
                path = model_path+'_'.join([model_name, str(in_ch), str(num_classes),dataset_name, '{:.4f}'.format(best_acc.item()),time_stamp]) + '.pth'
                print('save ', path.split('/')[-1])
                if parallel :
                    # torch.save(model.module.state_dict(), path)
                    torch.save(model.module, path)
                else:
                    # torch.save(model.state_dict(), path)
                    torch.save(model, path)
        print("Best valid Acc: {:4f}".format(best_acc))
        print('-' * 60)

    # Restore and persist the best weights collected during training.
    if config.save_best:
        model.load_state_dict(best_model_wts)
        path = model_path+'_'.join([model_name, str(in_ch), str(num_classes), str(width), dataset_name, '{:.4f}'.format(best_acc.item()),time_stamp]) + '.pth'
        print('save ', path)
        print('save ', path.split('/')[-1])
        if parallel :
            # torch.save(model.module.state_dict(), path)
            torch.save(model.module, path)
        else:
            # torch.save(model.state_dict(), path)
            torch.save(model, path)

    if 'test' in phases:
        test_model(model, dataloader['test'], classes, device)
    if config.tfboard:
        writer.close()
    print('\nComplete training\n')
def evaluation(sess, args, config):
    """Evaluate the da_cil model on the SOURCE test split (source-only).

    Builds the source-only test pipeline and metric graph, restores the
    latest checkpoint, then runs ``args.num_eval`` batches accumulating
    per-class steering accuracy for each of the three command branches
    (left / straight / right).

    NOTE(review): this file defines ``evaluation`` more than once; this
    last definition shadows the earlier ones at import time — confirm
    which is intended.

    Args:
        sess: a tf.Session with the default graph.
        args: parsed CLI args (batch_size, num_label, num_eval, print_info).
        config: ConfigParser-style experiment configuration.
    """
    # Resolve experiment paths from config.
    model_type = config.get('config', 'experiment')
    base_dir = os.path.expanduser(config.get('config', 'basedir'))
    log_dir = os.path.join(base_dir, config.get('config', 'logdir'))
    tfrecord_dir = os.path.join(base_dir, config.get(model_type, 'tfrecord'))
    ckpt_dir = os.path.join(log_dir, utils.make_savedir(config))
    print(ckpt_dir)

    # The experiment name doubles as the module that defines `model`.
    model_file = importlib.import_module(model_type)
    model = getattr(model_file, 'model')
    da_model = model(args, config)

    # dataset_utils exposes one batch-loader function per experiment type.
    get_batches = getattr(dataset_utils, model_type)

    # Get test batches
    with tf.name_scope('batch'):
        source_image_batch, source_label_batch, source_measure_batch, source_command_batch = get_batches('source', 'test', tfrecord_dir, batch_size=args.batch_size, config=config)

    # Call model (no target domain in source-only evaluation).
    da_model(source_image_batch, None, source_measure_batch)

    with tf.name_scope('objective'):
        da_model.create_objective(source_label_batch, source_command_batch)

    ''' Metrics '''
    # tf.nn.softmax's default axis is -1 (last dimension)
    # da_model.end holds the three command-branch logits (left/straight/right).
    command_l = tf.nn.softmax(da_model.end[0])
    command_s = tf.nn.softmax(da_model.end[1])
    command_r = tf.nn.softmax(da_model.end[2])

    # Predicted class as one-hot per branch.
    command_l_one_hot = slim.one_hot_encoding(tf.argmax(command_l, 1), args.num_label)
    command_s_one_hot = slim.one_hot_encoding(tf.argmax(command_s, 1), args.num_label)
    command_r_one_hot = slim.one_hot_encoding(tf.argmax(command_r, 1), args.num_label)

    # Label one-hot
    labels = slim.one_hot_encoding(tf.argmax(source_label_batch, 1), args.num_label)

    # Getting accuracy for each class
    # Per-class label counts, masked by which command branch each sample
    # belongs to (source_command_batch columns select the branch).
    command_l_labels = tf.reduce_sum(tf.expand_dims(source_command_batch[:,0], 1) * labels, 0)
    command_s_labels = tf.reduce_sum(tf.expand_dims(source_command_batch[:,1], 1) * labels, 0)
    command_r_labels = tf.reduce_sum(tf.expand_dims(source_command_batch[:,2], 1) * labels, 0)

    # Per-class correct-prediction counts, masked per command branch.
    command_l_correct = tf.reduce_sum(tf.expand_dims(source_command_batch[:,0], 1) * tf.multiply(labels, command_l_one_hot), 0)
    command_s_correct = tf.reduce_sum(tf.expand_dims(source_command_batch[:,1], 1) * tf.multiply(labels, command_s_one_hot), 0)
    command_r_correct = tf.reduce_sum(tf.expand_dims(source_command_batch[:,2], 1) * tf.multiply(labels, command_r_one_hot), 0)

    sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))

    # Load checkpoint
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        tf.logging.info('Checkpoint loaded from %s' % ckpt_dir)
    else:
        # Evaluating uninitialized weights is meaningless — bail out.
        tf.logging.warn('Checkpoint not loaded')
        return

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        # NOTE(review): accumulators are hard-coded to 5 slots but indexed
        # by range(args.num_label) below — assumes num_label <= 5; confirm.
        left_label_num = np.zeros(5)
        str_label_num = np.zeros(5)
        right_label_num = np.zeros(5)
        left_acc = np.zeros(5)
        str_acc = np.zeros(5)
        right_acc = np.zeros(5)
        for idx in range(args.num_eval):
            # One sess.run so every metric is computed on the SAME batch.
            command, steer_label, np_im, steer_pred_l, steer_pred_s, steer_pred_r, cmd_l_label_num, cmd_s_label_num, cmd_r_label_num, cmd_l_correct, cmd_s_correct, cmd_r_correct \
                    = sess.run([source_command_batch, source_label_batch, source_image_batch, command_l, command_s, command_r, command_l_labels, command_s_labels, command_r_labels, command_l_correct, command_s_correct, command_r_correct])
            # NOTE(review): grouping of the diagnostic prints under
            # print_info was reconstructed from flattened source — confirm.
            if args.print_info:
                print('COMMAND LABEL')
                print(command)
                print('STEER')
                print(steer_label)
                print('-'*20)
                accuracy(steer_pred_l, steer_pred_s, steer_pred_r, args.batch_size, command, args.num_label)
                print(cmd_l_label_num)
                print(cmd_s_label_num)
                print(cmd_r_label_num)
                print('-'*10)
                print(cmd_l_correct)
                print(cmd_s_correct)
                print(cmd_r_correct)
            # Accumulate per-class counts across evaluation batches.
            left_label_num += cmd_l_label_num
            str_label_num += cmd_s_label_num
            right_label_num += cmd_r_label_num
            left_acc += cmd_l_correct
            str_acc += cmd_s_correct
            right_acc += cmd_r_correct
        # Convert correct-counts into per-class accuracy.
        # NOTE(review): divides by zero (NaN/inf) for classes never seen
        # during evaluation — confirm acceptable.
        for i in range(args.num_label):
            left_acc[i] = left_acc[i] / left_label_num[i]
            str_acc[i] = str_acc[i] / str_label_num[i]
            right_acc[i] = right_acc[i] / right_label_num[i]
        print('LEFT COMMAND ACCURACY')
        print(left_acc)
        print('STRAIGHT COMMAND ACCURACY')
        print(str_acc)
        print('RIGHT COMMAND ACCURACY')
        print(right_acc)
    finally:
        # Always stop queue-runner threads, even on error/interrupt.
        coord.request_stop()
        coord.join(threads)