Пример #1
0
def get_tfboard_writer(save_dir, model_name, dataset_name, save_time=None):
    """Create a TensorBoard ``SummaryWriter`` for one training run.

    The log directory is built by plain string concatenation as
    ``save_dir + 'TensorBoard/' + '<model_name>_<dataset_name>_<save_time>'``,
    so ``save_dir`` is expected to already end with a path separator.

    Args:
        save_dir: Base output directory (string, with trailing separator).
        model_name: Model identifier used in the run directory name.
        dataset_name: Dataset identifier used in the run directory name.
        save_time: Time tag for the run directory name. When ``None`` the
            module-level ``time_stamp`` is used.

    Returns:
        A ``(writer, tfboard_path)`` tuple: the ``SummaryWriter`` and the
        directory it writes to.
    """
    # Identity check: ``None`` is a singleton and ``==`` may be overridden.
    if save_time is None:
        save_time = time_stamp
    run_name = '_'.join([model_name, dataset_name, save_time])
    tfboard_path = save_dir + 'TensorBoard/' + run_name
    make_savedir(tfboard_path)
    writer = SummaryWriter(tfboard_path)
    return writer, tfboard_path
Пример #2
0
def evaluation(sess, args, config):
    """Evaluate a restored model on the *target* test split, per command.

    Builds the source/target test input pipelines and the model graph,
    restores the latest checkpoint from the run's save directory and then
    runs ``args.num_eval`` evaluation steps. Each step also dumps
    translated images (raw and histogram-matched) to PNG files in the
    current working directory, overwriting the files from the previous
    step. Per-class accuracy is printed separately for the left, straight
    and right command branches.

    Args:
        sess: TensorFlow session used to run the graph.
        args: Parsed options; reads ``batch_size``, ``num_label``,
            ``num_eval`` and ``print_info``.
        config: ``ConfigParser``-like object; reads the ``config`` section
            and the section named after the experiment.

    Returns:
        None. Returns early (after logging a warning) when no checkpoint
        is found.
    """
    model_type = config.get('config', 'experiment')
    base_dir = os.path.expanduser(config.get('config', 'basedir'))
    log_dir = os.path.join(base_dir, config.get('config', 'logdir'))
    tfrecord_dir = os.path.join(base_dir, config.get(model_type, 'tfrecord'))
    ckpt_dir = os.path.join(log_dir, utils.make_savedir(config))
    print(ckpt_dir)

    # The experiment name doubles as the module that defines ``model`` and
    # as the name of the batch factory inside ``dataset_utils``.
    model_file = importlib.import_module(model_type)
    model = getattr(model_file, 'model')

    da_model = model(args, config)

    get_batches = getattr(dataset_utils, model_type)

    # Get test batches
    with tf.name_scope('batch'):
        source_image_batch, source_label_batch, source_measure_batch, source_command_batch = get_batches(
            'source', 'test', tfrecord_dir, batch_size=args.batch_size, config=config)
        target_image_batch, target_label_batch, target_measure_batch, target_command_batch = get_batches(
            'target', 'test', tfrecord_dir, batch_size=args.batch_size, config=config)

    # Call model
    da_model(source_image_batch, target_image_batch, source_measure_batch)

    with tf.name_scope('objective'):
        da_model.create_objective(source_label_batch, source_command_batch)

    # One prediction head per command: end[0]=left, end[1]=straight,
    # end[2]=right (matching the labels printed at the end).
    command_l = tf.nn.softmax(da_model.end[0])
    command_s = tf.nn.softmax(da_model.end[1])
    command_r = tf.nn.softmax(da_model.end[2])
    command_l_one_hot = slim.one_hot_encoding(tf.argmax(command_l, 1), args.num_label)
    command_s_one_hot = slim.one_hot_encoding(tf.argmax(command_s, 1), args.num_label)
    command_r_one_hot = slim.one_hot_encoding(tf.argmax(command_r, 1), args.num_label)

    labels = slim.one_hot_encoding(tf.argmax(target_label_batch, 1), args.num_label)

    # Per-class sample counts, weighted by the command column of each sample.
    command_l_labels = tf.reduce_sum(tf.expand_dims(target_command_batch[:, 0], 1) * labels, 0)
    command_s_labels = tf.reduce_sum(tf.expand_dims(target_command_batch[:, 1], 1) * labels, 0)
    command_r_labels = tf.reduce_sum(tf.expand_dims(target_command_batch[:, 2], 1) * labels, 0)

    # Per-class correct-prediction counts, weighted the same way.
    command_l_correct = tf.reduce_sum(
        tf.expand_dims(target_command_batch[:, 0], 1) * tf.multiply(labels, command_l_one_hot), 0)
    command_s_correct = tf.reduce_sum(
        tf.expand_dims(target_command_batch[:, 1], 1) * tf.multiply(labels, command_s_one_hot), 0)
    command_r_correct = tf.reduce_sum(
        tf.expand_dims(target_command_batch[:, 2], 1) * tf.multiply(labels, command_r_one_hot), 0)

    sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
    # Load checkpoint
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        tf.logging.info('Checkpoint loaded from %s' % ckpt_dir)
    else:
        tf.logging.warn('Checkpoint not loaded')
        return

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)

    try:
        # Accumulators are sized from ``args.num_label`` so they always match
        # the per-batch count vectors produced by the graph above.
        left_label_num = np.zeros(args.num_label)
        str_label_num = np.zeros(args.num_label)
        right_label_num = np.zeros(args.num_label)
        left_acc = np.zeros(args.num_label)
        str_acc = np.zeros(args.num_label)
        right_acc = np.zeros(args.num_label)

        for _ in range(args.num_eval):
            sim, tim, s2tim, t2sim = sess.run(
                [da_model.source_concat, da_model.target_concat, da_model.g_s2t, da_model.g_t2s])
            s2tim_mat = hist_matching(s2tim, tim)
            t2sim_mat = hist_matching(t2sim, sim)
            # Only the first image of each batch is written out.
            imsave('s2tim.png', s2tim[0, :, :, :])
            imsave('s2tim_mat.png', np.uint8((s2tim_mat[0, :, :, :] + 1) / 2 * 255))
            imsave('t2sim.png', t2sim[0, :, :, :])
            imsave('t2sim_mat.png', t2sim_mat[0, :, :, :])

            # NOTE: this is a second ``sess.run``; the metrics below are
            # computed on a different batch than the images saved above.
            command, steer_label, np_im, steer_pred_l, steer_pred_s, steer_pred_r, \
                cmd_l_label_num, cmd_s_label_num, cmd_r_label_num, \
                cmd_l_correct, cmd_s_correct, cmd_r_correct = sess.run([
                    target_command_batch, target_label_batch, target_image_batch,
                    command_l, command_s, command_r,
                    command_l_labels, command_s_labels, command_r_labels,
                    command_l_correct, command_s_correct, command_r_correct])

            if args.print_info:
                print('COMMAND LABEL')
                print(command)
                print('STEER')
                print(steer_label)
                print('-' * 20)
                accuracy(steer_pred_l, steer_pred_s, steer_pred_r, args.batch_size, command, args.num_label)
                print(cmd_l_label_num)
                print(cmd_s_label_num)
                print(cmd_r_label_num)
                print('-' * 10)
                print(cmd_l_correct)
                print(cmd_s_correct)
                print(cmd_r_correct)

            left_label_num += cmd_l_label_num
            str_label_num += cmd_s_label_num
            right_label_num += cmd_r_label_num
            left_acc += cmd_l_correct
            str_acc += cmd_s_correct
            right_acc += cmd_r_correct

        # Turn correct counts into per-class accuracy. A class that never
        # occurred has a zero count and yields ``nan``.
        for i in range(args.num_label):
            left_acc[i] = left_acc[i] / left_label_num[i]
            str_acc[i] = str_acc[i] / str_label_num[i]
            right_acc[i] = right_acc[i] / right_label_num[i]

        print('LEFT COMMAND ACCURACY')
        print(left_acc)
        print('STRAIGHT COMMAND ACCURACY')
        print(str_acc)
        print('RIGHT COMMAND ACCURACY')
        print(right_acc)
    finally:
        # Always stop the input queue threads, even on error.
        coord.request_stop()
        coord.join(threads)
Пример #3
0
def train(sess, args, config):
    """Train a ``da_cil`` or ``pixel_da`` model with alternating updates.

    Reads hyper-parameters from ``config``, builds the input pipelines,
    model graph, losses and optimizers, then alternates
    ``discriminator_step`` discriminator updates with ``generator_step``
    generator updates for ``args.max_iter`` iterations. Checkpoints and
    summaries are written to the run's save directory at the intervals
    given by ``args.save_interval`` and ``args.summary_interval``.

    With ``source_only`` enabled (``da_cil`` branch only) just the task
    loss is optimised and no generator is built.

    Args:
        sess: TensorFlow session used to run the graph.
        args: Parsed options; reads ``delete``, ``batch_size``, ``lr_decay``,
            ``learning_rate``, ``optimizer``, ``clip_norm``, ``load_ckpt``,
            ``max_iter``, ``save_interval`` and ``summary_interval``.
        config: ``ConfigParser``-like object with the ``config``,
            ``generator``, ``optimizer`` and experiment-named sections.

    Raises:
        Exception: If the configured experiment is neither ``'da_cil'`` nor
            ``'pixel_da'``.
    """
    model_type = config.get('config', 'experiment')
    base_dir = os.path.expanduser(config.get('config', 'basedir'))
    tfrecord_dir = os.path.join(base_dir, config.get(model_type, 'tfrecord'))
    log_dir = os.path.join(base_dir, config.get('config', 'logdir'))
    adversarial_mode = config.get('config', 'mode')
    whether_noise = config.getboolean('generator', 'noise')
    t2s_task = config.getboolean('config', 't2s_task')
    noise_dim = config.getint('generator', 'noise_dim')

    source_only = config.getboolean('config', 'source_only')
    s2t_adversarial_weight = config.getfloat(model_type, 's2t_adversarial_weight')
    t2s_adversarial_weight = config.getfloat(model_type, 't2s_adversarial_weight')
    s2t_cyclic_weight = config.getfloat(model_type, 's2t_cyclic_weight')
    t2s_cyclic_weight = config.getfloat(model_type, 't2s_cyclic_weight')
    s2t_task_weight = config.getfloat(model_type, 'task_weight')
    t2s_task_weight = config.getfloat(model_type, 't2s_task_weight')
    s2t_style_weight = config.getfloat(model_type, 's2t_style_weight')
    t2s_style_weight = config.getfloat(model_type, 't2s_style_weight')
    discriminator_step = config.getint(model_type, 'discriminator_step')
    generator_step = config.getint(model_type, 'generator_step')
    save_dir = os.path.join(log_dir, utils.make_savedir(config))

    # Optionally start from a clean run directory.
    if args.delete and os.path.exists(save_dir):
        shutil.rmtree(save_dir)

    os.makedirs(save_dir, exist_ok=True)

    # The experiment name doubles as the module that defines ``model`` and
    # as the name of the batch factory inside ``dataset_utils``.
    model_path = importlib.import_module(model_type)
    model = getattr(model_path, 'model')

    da_model = model(args, config)

    writer = tf.summary.FileWriter(save_dir, sess.graph)
    global_step = tf.train.get_or_create_global_step()

    get_batches = getattr(dataset_utils, model_type)

    tf.logging.info('Training %s with %s' % (model_type, adversarial_mode))

    if model_type == 'da_cil':
        with tf.name_scope(model_type + '_batches'):
            source_image_batch, source_label_batch, source_measure_batch, source_command_batch = get_batches(
                'source', 'train', tfrecord_dir, batch_size=args.batch_size, config=config, args=args)

        if source_only:
            da_model(source_image_batch, None, source_measure_batch)
        else:
            target_image_batch, _, _, _ = get_batches(
                'target', 'train', tfrecord_dir, batch_size=args.batch_size, config=config, args=args)
            da_model(source_image_batch, target_image_batch, source_measure_batch)

        with tf.name_scope(model_type + '_objectives'):
            da_model.create_objective(source_label_batch, source_command_batch, adversarial_mode)

            if source_only:
                # No adversarial game: the "discriminator" update is just the
                # supervised task loss.
                discriminator_loss = da_model.task_loss
                da_model.summary['discriminator_loss'] = discriminator_loss
            else:
                generator_loss = (s2t_cyclic_weight * da_model.s2t_cyclic_loss
                                  + t2s_cyclic_weight * da_model.t2s_cyclic_loss
                                  + da_model.s2t_adversarial_loss[0]
                                  + da_model.t2s_adversarial_loss[0])
                generator_loss += (s2t_style_weight * da_model.s2t_style_loss
                                   + t2s_style_weight * da_model.t2s_style_loss)
                da_model.summary['generator_loss'] = generator_loss

                discriminator_loss = (s2t_adversarial_weight * da_model.s2t_adversarial_loss[1]
                                      + t2s_adversarial_weight * da_model.t2s_adversarial_loss[1]
                                      + s2t_task_weight * da_model.task_loss)
                if t2s_task:
                    discriminator_loss += t2s_task_weight * da_model.t2s_task_loss
                da_model.summary['discriminator_loss'] = discriminator_loss

    elif model_type == 'pixel_da':
        with tf.name_scope(model_type + '_batches'):
            source_image_batch, source_label_batch = get_batches(
                'source', 'train', tfrecord_dir, batch_size=args.batch_size, config=config)
            # Channel 3 of the source batch is used as a mask; channels 0-2
            # are the image.
            mask_image_batch = source_image_batch[:, :, :, 3]
            source_image_batch = source_image_batch[:, :, :, :3]
            if config.getboolean(model_type, 'input_mask'):
                tf.logging.info('Using masked input')
                # Binarize the mask (> 0.9 -> 1.0, else 0.0) and apply it to
                # every image channel.
                mask_images = tf.to_float(tf.greater(mask_image_batch, 0.9))
                source_image_batch = tf.multiply(
                    source_image_batch,
                    tf.tile(tf.expand_dims(mask_images, 3), [1, 1, 1, 3]))

            # Label is already an 1-hot labels, but we expect categorical
            source_label_max_batch = tf.argmax(source_label_batch, 1)
            # NOTE(review): on Python 3 ``/`` is true division on an integer
            # tensor - confirm whether floor division (``//``) was intended.
            source_lateral_label_batch = (source_label_max_batch % 9) / 3
            source_head_label_batch = source_label_max_batch % 3

            target_image_batch, _ = get_batches(
                'target', 'train', tfrecord_dir, batch_size=args.batch_size, config=config)

        da_model(source_image_batch, target_image_batch)

        with tf.name_scope(model_type + '_objectives'):
            da_model.create_objective(source_head_label_batch, source_lateral_label_batch, adversarial_mode)

            generator_loss = (s2t_cyclic_weight * da_model.s2t_cyclic_loss
                              + t2s_cyclic_weight * da_model.t2s_cyclic_loss
                              + da_model.s2t_adversarial_loss[0]
                              + da_model.t2s_adversarial_loss[0])
            generator_loss += (s2t_style_weight * da_model.s2t_style_loss
                               + t2s_style_weight * da_model.t2s_style_loss)
            da_model.summary['generator_loss'] = generator_loss

            discriminator_loss = (s2t_adversarial_weight * da_model.s2t_adversarial_loss[1]
                                  + t2s_adversarial_weight * da_model.t2s_adversarial_loss[1]
                                  + s2t_task_weight * da_model.transferred_task_loss)
            if t2s_task:
                discriminator_loss += t2s_task_weight * da_model.t2s_task_loss
            da_model.summary['discriminator_loss'] = discriminator_loss

    else:
        raise Exception('Not supported model')

    with tf.name_scope('optimizer'):
        tf.logging.info('Getting optimizer')
        if args.lr_decay:
            decay_steps = config.getint('optimizer', 'decay_steps')
            decay_rate = config.getfloat('optimizer', 'decay_rate')
            learning_rate = tf.train.exponential_decay(
                args.learning_rate, global_step, decay_steps, decay_rate, staircase=True)
        else:
            learning_rate = args.learning_rate

        if not source_only:
            g_optimizer = _get_optimizer(config, args.optimizer)(learning_rate)
            g_optim = _gradient_clip(name='generator', optimizer=g_optimizer, loss=generator_loss,
                                     global_steps=global_step, clip_norm=args.clip_norm)
        d_optimizer = _get_optimizer(config, args.optimizer)(learning_rate)
        d_optim = _gradient_clip(name='discriminator', optimizer=d_optimizer, loss=discriminator_loss,
                                 global_steps=global_step, clip_norm=args.clip_norm)
    if not source_only:
        generator_summary, discriminator_summary = utils.summarize(da_model.summary, t2s_task)
        utils.config_summary(save_dir, s2t_adversarial_weight, t2s_adversarial_weight,
                             s2t_cyclic_weight, t2s_cyclic_weight, s2t_task_weight, t2s_task_weight,
                             discriminator_step, generator_step, adversarial_mode, whether_noise,
                             noise_dim, s2t_style_weight, t2s_style_weight)
    else:
        discriminator_summary = utils.summarize(da_model.summary, t2s_task, source_only)

    saver = tf.train.Saver(max_to_keep=5)
    sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))

    # Resume from the latest checkpoint in the run directory, if requested
    # and one exists.
    if args.load_ckpt:
        ckpt = tf.train.get_checkpoint_state(save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)

    try:
        for iter_count in range(args.max_iter):
            # Update discriminator
            for _ in range(discriminator_step):
                d_loss, _, steps = sess.run([discriminator_loss, d_optim, global_step])
                if not source_only and adversarial_mode == 'FISHER':
                    # In FISHER mode the last entry of each adversarial loss
                    # tuple is run as an extra op after every update.
                    _, _ = sess.run([da_model.s2t_adversarial_loss[-1], da_model.t2s_adversarial_loss[-1]])
                tf.logging.info('Step %d: Discriminator loss=%.5f', steps, d_loss)

            if not source_only:
                for _ in range(generator_step):
                    g_loss, _, steps = sess.run([generator_loss, g_optim, global_step])
                    tf.logging.info('Step %d: Generator loss=%.5f', steps, g_loss)

            if (iter_count + 1) % args.save_interval == 0:
                saver.save(sess, os.path.join(save_dir, model_type), global_step=(iter_count + 1))
                tf.logging.info('Checkpoint save')

            if (iter_count + 1) % args.summary_interval == 0:
                if not source_only:
                    disc_sum, gen_sum = sess.run([discriminator_summary, generator_summary])
                    writer.add_summary(gen_sum, steps)
                else:
                    disc_sum = sess.run(discriminator_summary)

                writer.add_summary(disc_sum, steps)
                tf.logging.info('Summary at %d step' % (iter_count + 1))

    except tf.errors.OutOfRangeError:
        print('Epoch limited')
    except KeyboardInterrupt:
        print('End training')
    finally:
        try:
            # Always stop the input queue threads, even on error.
            coord.request_stop()
            coord.join(threads)
        finally:
            # Flush buffered summaries and release the event file.
            writer.close()
Пример #4
0
def evaluation(sess, args, config):
    """Evaluate a restored ``pixel_da``-style model on the target test split.

    Builds the source/target test pipelines and the model graph, restores
    the latest checkpoint from the run's save directory, saves four
    source/translated image pairs from the first batch under
    ``<ckpt_dir>/im`` and then accumulates, over the test batches,
    per-class soft confusion tables, accuracies and mean "command" values
    for the lateral and head outputs. All results are printed.

    Args:
        sess: TensorFlow session used to run the graph.
        args: Parsed options; reads ``batch_size``.
        config: ``ConfigParser``-like object; reads the ``config`` section
            and the section named after the experiment.

    Returns:
        None. Returns early (after logging a warning) when no checkpoint
        is found.
    """
    model_type = config.get('config', 'experiment')
    base_dir = os.path.expanduser(config.get('config', 'basedir'))
    log_dir = os.path.join(base_dir, config.get('config', 'logdir'))
    tfrecord_dir = os.path.join(base_dir, config.get(model_type, 'tfrecord'))
    ckpt_dir = os.path.join(log_dir, utils.make_savedir(config))
    print(ckpt_dir)

    model_file = importlib.import_module(model_type)
    model = getattr(model_file, 'model')

    da_model = model(args, config)

    # Get test batches
    with tf.name_scope('batch'):
        source_image_batch, source_label_batch = dataset_utils.get_batches(
            'source',
            'test',
            tfrecord_dir,
            batch_size=args.batch_size,
            config=config)
        # Channel 3 of the source batch is used as a mask; channels 0-2 are
        # the image.
        mask_image_batch = source_image_batch[:, :, :, 3]
        source_image_batch = source_image_batch[:, :, :, :3]
        if config.getboolean('config', 'input_mask'):
            # Binarize the mask (> 0.99 -> 1.0, else 0.0) and apply it to
            # every image channel.
            mask_images = tf.to_float(tf.greater(mask_image_batch, 0.99))
            source_image_batch = tf.multiply(
                source_image_batch,
                tf.tile(tf.expand_dims(mask_images, 3), [1, 1, 1, 3]))

        target_image_batch, target_label_batch = dataset_utils.get_batches(
            'target',
            'test',
            tfrecord_dir,
            batch_size=args.batch_size,
            config=config)
        target_label_max_batch = tf.argmax(target_label_batch, 1)
        # NOTE(review): on Python 3 ``/`` is true division on an integer
        # tensor - confirm ``one_hot`` accepts the result or whether floor
        # division (``//``) was intended.
        target_lateral_label_batch = (target_label_max_batch % 9) / 3
        target_head_label_batch = target_label_max_batch % 3

    # Call model
    da_model(source_image_batch, target_image_batch)

    # Load checkpoint
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        tf.logging.info('Checkpoint loaded from %s' % ckpt_dir)
    else:
        tf.logging.warn('Checkpoint not loaded')
        return

    target_lateral_one_hot_label = one_hot(target_lateral_label_batch)
    target_head_one_hot_label = one_hot(target_head_label_batch)
    # Per-class sample counts in the batch.
    lateral_label_num = tf.reduce_sum(target_lateral_one_hot_label, 0)
    head_label_num = tf.reduce_sum(target_head_one_hot_label, 0)

    lateral_command = get_command(da_model.target_lateral_logits)
    lateral_cmd_label = tf.multiply(target_lateral_one_hot_label,
                                    lateral_command)
    head_command = get_command(da_model.target_head_logits)
    head_cmd_label = tf.multiply(target_head_one_hot_label, head_command)

    model_lateral_one_hot = one_hot(
        tf.argmax(da_model.target_lateral_logits, 1))
    model_head_one_hot = one_hot(tf.argmax(da_model.target_head_logits, 1))

    # Per-class counts of correct arg-max predictions.
    lateral_correct_one = tf.reduce_sum(
        tf.multiply(model_lateral_one_hot, target_lateral_one_hot_label), 0)
    head_correct_one = tf.reduce_sum(
        tf.multiply(model_head_one_hot, target_head_one_hot_label), 0)

    target_lateral_one_hot_label = tf.expand_dims(target_lateral_one_hot_label,
                                                  2)
    target_head_one_hot_label = tf.expand_dims(target_head_one_hot_label, 2)

    # Per-sample soft confusion tables: ground-truth one-hot (rows) times
    # predicted softmax probabilities (columns).
    lateral_confusion_table = tf.multiply(
        target_lateral_one_hot_label,
        tf.expand_dims(tf.nn.softmax(da_model.target_lateral_logits), 1))
    head_confusion_table = tf.multiply(
        target_head_one_hot_label,
        tf.expand_dims(tf.nn.softmax(da_model.target_head_logits), 1))

    # Batch-level reductions are built once here. Creating them inside the
    # evaluation loop would add new ops to the graph on every iteration.
    lateral_confusion_sum = tf.reduce_sum(lateral_confusion_table, 0)
    head_confusion_sum = tf.reduce_sum(head_confusion_table, 0)
    lateral_cmd_sum = tf.reduce_sum(lateral_cmd_label, 0)
    head_cmd_sum = tf.reduce_sum(head_cmd_label, 0)
    fetches = [
        lateral_confusion_table, head_confusion_table,
        lateral_confusion_sum, head_confusion_sum,
        lateral_label_num, head_label_num,
        lateral_correct_one, head_correct_one,
        lateral_cmd_sum, head_cmd_sum,
    ]

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)

    try:
        np_lateral_confusion = np.zeros((3, 3))
        np_head_confusion = np.zeros((3, 3))
        np_lateral_num = np.zeros(3)
        np_head_num = np.zeros(3)
        np_lateral_acc = np.zeros(3)
        np_head_acc = np.zeros(3)
        np_lateral_cmd = np.zeros(3)
        np_head_cmd = np.zeros(3)
        # Per-sample tables are collected in lists and concatenated once
        # after the loop. The leading all-zero entry is kept so the final
        # arrays are never empty; it is filtered out by the ``z != 0``
        # masks below.
        l_conf_chunks = [np.zeros((1, 3, 3))]
        h_conf_chunks = [np.zeros((1, 3, 3))]

        # NOTE(review): 20000 is presumably the test-set size - confirm.
        test_batch_num = int(math.floor(20000 / args.batch_size))

        for idx in range(test_batch_num):
            if idx == 0:
                # Dump a few source / translated image pairs for inspection.
                source, transferred_im = sess.run(
                    [source_image_batch, da_model.g_s2t])
                import scipy.misc
                im_dir = os.path.join(ckpt_dir, 'im')
                if not os.path.exists(im_dir):
                    os.mkdir(im_dir)
                for im_idx in range(4):
                    scipy.misc.imsave(
                        os.path.join(im_dir, 'source%d.png' % im_idx),
                        source[im_idx, :, :, :])
                    scipy.misc.imsave(
                        os.path.join(im_dir, 'trans%d.png' % im_idx),
                        transferred_im[im_idx, :, :, :])

            (l_conf, h_conf, l_conf_sum, h_conf_sum, l_num, h_num,
             l_correct, h_correct, l_cmd, h_cmd) = sess.run(fetches)
            np_lateral_confusion = np_lateral_confusion + l_conf_sum
            np_head_confusion = np_head_confusion + h_conf_sum
            np_lateral_num = np_lateral_num + l_num
            np_head_num = np_head_num + h_num
            np_lateral_acc = np_lateral_acc + l_correct
            np_head_acc = np_head_acc + h_correct
            np_lateral_cmd += l_cmd
            np_head_cmd += h_cmd
            l_conf_chunks.append(l_conf)
            h_conf_chunks.append(h_conf)
            printProgress(idx, test_batch_num, 'Progress', 'Complete', 1, 50)

        np_l_conf_array = np.concatenate(l_conf_chunks, 0)
        np_h_conf_array = np.concatenate(h_conf_chunks, 0)

        # Normalize every accumulated statistic by its class sample count.
        for i in range(3):
            np_lateral_confusion[i, :] /= np_lateral_num[i]
            np_head_confusion[i, :] /= np_head_num[i]
            np_lateral_acc[i] /= np_lateral_num[i]
            np_head_acc[i] /= np_head_num[i]
            np_lateral_cmd[i] /= np_lateral_num[i]
            np_head_cmd[i] /= np_head_num[i]
        print('\nlateral_confusion')
        print(np_lateral_confusion)
        print('head_confusion')
        print(np_head_confusion)
        print('lateral-accuracy')
        print(np_lateral_acc)
        print('head-accuracy')
        print(np_head_acc)
        print('lateral cmd')
        print(np_lateral_cmd)
        print('head cmd')
        print(np_head_cmd)

        for i in range(3):
            # p(class 2) - p(class 0) per sample, restricted to samples
            # whose ground truth is class ``i`` (rows of other classes are
            # all zero, so ``z != 0`` selects them out).
            y = np_l_conf_array[:, i, 2] - np_l_conf_array[:, i, 0]
            z = np_l_conf_array[:, i, 0]
            print('when GT lateral is %d command mean : %.4f std : %.4f' %
                  (i, np.mean(y[z != 0]), np.std(y[z != 0])))
            y = np_h_conf_array[:, i, 2] - np_h_conf_array[:, i, 0]
            z = np_h_conf_array[:, i, 0]
            print('when GT head is %d command mean : %.4f std : %.4f' %
                  (i, np.mean(y[z != 0]), np.std(y[z != 0])))

    finally:
        # Always stop the input queue threads, even on error.
        coord.request_stop()
        coord.join(threads)
Пример #5
0
def train_model(config, save_dir):
    """Train a classifier, track the best accuracy and save the model.

    Loads the dataloaders, builds the model named by the first non-``None``
    ``config`` attribute whose name contains ``'model'``, and trains it for
    ``config.epoch`` epochs with cross-entropy loss. Each epoch runs the
    ``'train'`` and ``'valid'`` phases in dataloader order; any other phase
    ends the phase loop for that epoch. When ``config.tfboard`` is set,
    losses, accuracies and a prediction figure are logged and TensorBoard
    is launched in a browser after the first epoch.

    With ``config.save_best`` the best weights are kept in memory and saved
    once after training; otherwise the model is saved every time the
    accuracy improves. If a ``'test'`` dataloader exists, ``test_model`` is
    run at the end.

    Args:
        config: Namespace-like options object; reads ``preTrain``,
            ``save_best``, ``tfboard``, ``epoch`` and the ``*model*``
            attribute.
        save_dir: Base output directory (string, with trailing separator -
            paths are built by plain concatenation).

    Raises:
        ValueError: If ``config`` has no non-``None`` attribute whose name
            contains ``'model'``.
    """
    model_path = save_dir + 'Model/'
    make_savedir(model_path)

    # load dataset
    dataloader, len_dataset, dataset_name = load_dataloader(config, save_dir)
    first_sample = dataloader['train'].dataset[0][0]
    in_ch = first_sample.shape[0]
    classes = dataloader['train'].dataset.classes
    width = first_sample.shape[-1]
    phases = dataloader.keys()
    num_classes = len(classes)

    # load model
    model_name = next(
        (value for key, value in vars(config).items()
         if 'model' in key and value is not None),
        None)
    if model_name is None:
        raise ValueError(
            "config has no non-None attribute whose name contains 'model'")

    model = get_model(model_name, in_ch, num_classes, config.preTrain)
    model.classes = classes
    model, device, parallel = config_device(config, model)

    criterion = nn.CrossEntropyLoss()
    optimizer = get_optim(config, model)

    if config.save_best:
        best_model_wts = copy.deepcopy(model.state_dict())

    if config.tfboard:
        writer, tfboard_path = get_tfboard_writer(save_dir, model_name, dataset_name)
        images, _ = next(iter(dataloader['train']))
        # ``add_graph`` needs the bare module, not the parallel wrapper.
        if parallel:
            writer.add_graph(model.module, images.to(device))
        else:
            writer.add_graph(model, images.to(device))

    num_epochs = config.epoch
    best_acc = 0

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        print('\n\nEpoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 60)

        for phase in phases:
            if phase == 'train':
                model.train()
            elif phase == 'valid':
                model.eval()
            else:
                break
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloader[phase]:
                # Single-sample batches are skipped.
                if len(labels) == 1:
                    continue
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the batch loss weighted by batch size so that
                # dividing by the dataset size gives the mean sample loss.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len_dataset[phase]
            epoch_acc = running_corrects.double() / len_dataset[phase]

            print('<{}>\t\tLoss : {:.4f}\t\tAcc : {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            if config.tfboard:
                writer.add_scalar(phase + '_loss', epoch_loss, global_step=epoch + 1)
                writer.add_scalar(phase + '_acc', epoch_acc, global_step=epoch + 1)

        if config.tfboard:
            # Uses the last batch seen in the phase loop above.
            writer.add_figure('valid : predictions vs. actuals',
                              plot_classes_preds(model, inputs, classes, labels),
                              global_step=epoch + 1)
            if epoch == 0:
                tb = program.TensorBoard()
                tb.configure(argv=[None, '--logdir', tfboard_path])
                url = tb.launch()
                webbrowser.open(url)

        time_elapsed = time.time() - epoch_start_time
        print('elapsed time : {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

        # ``epoch_acc`` is the accuracy of the last phase that ran.
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            if config.save_best:
                best_model_wts = copy.deepcopy(model.state_dict())
            else:
                path = model_path + '_'.join([
                    model_name, str(in_ch), str(num_classes), dataset_name,
                    '{:.4f}'.format(float(best_acc)), time_stamp]) + '.pth'
                print('save ', path.split('/')[-1])
                # The whole module object is saved, not just its state dict.
                if parallel:
                    torch.save(model.module, path)
                else:
                    torch.save(model, path)
        print("Best valid Acc: {:4f}".format(best_acc))
        print('-' * 60)

    if config.save_best:
        model.load_state_dict(best_model_wts)
        # ``float()`` handles both a tensor and the initial int ``0`` (when
        # accuracy never improved), where ``.item()`` would fail.
        path = model_path + '_'.join([
            model_name, str(in_ch), str(num_classes), str(width), dataset_name,
            '{:.4f}'.format(float(best_acc)), time_stamp]) + '.pth'
        print('save ', path)
        print('save ', path.split('/')[-1])
        # The whole module object is saved, not just its state dict.
        if parallel:
            torch.save(model.module, path)
        else:
            torch.save(model, path)

    if 'test' in phases:
        test_model(model, dataloader['test'], classes, device)

    if config.tfboard:
        writer.close()
    print('\nComplete training\n')
Пример #6
0
def evaluation(sess, args, config):
    """Evaluate a restored model on the *source* test split, per command.

    Builds the source test input pipeline and the model graph (without a
    target input), restores the latest checkpoint from the run's save
    directory and runs ``args.num_eval`` evaluation batches. Per-class
    accuracy is printed separately for the left, straight and right
    command branches.

    Args:
        sess: TensorFlow session used to run the graph.
        args: Parsed options; reads ``batch_size``, ``num_label``,
            ``num_eval`` and ``print_info``.
        config: ``ConfigParser``-like object; reads the ``config`` section
            and the section named after the experiment.

    Returns:
        None. Returns early (after logging a warning) when no checkpoint
        is found.
    """
    model_type = config.get('config', 'experiment')
    base_dir = os.path.expanduser(config.get('config', 'basedir'))
    log_dir = os.path.join(base_dir, config.get('config', 'logdir'))
    tfrecord_dir = os.path.join(base_dir, config.get(model_type, 'tfrecord'))
    ckpt_dir = os.path.join(log_dir, utils.make_savedir(config))
    print(ckpt_dir)

    # The experiment name doubles as the module that defines ``model`` and
    # as the name of the batch factory inside ``dataset_utils``.
    model_file = importlib.import_module(model_type)
    model = getattr(model_file, 'model')

    da_model = model(args, config)

    get_batches = getattr(dataset_utils, model_type)

    # Get test batches
    with tf.name_scope('batch'):
        source_image_batch, source_label_batch, source_measure_batch, source_command_batch = get_batches(
            'source', 'test', tfrecord_dir, batch_size=args.batch_size, config=config)

    # Call model
    da_model(source_image_batch, None, source_measure_batch)

    with tf.name_scope('objective'):
        da_model.create_objective(source_label_batch, source_command_batch)

    # Metrics
    # tf.nn.softmax's default axis is -1 (last dimension)
    # One prediction head per command: end[0]=left, end[1]=straight,
    # end[2]=right (matching the labels printed at the end).
    command_l = tf.nn.softmax(da_model.end[0])
    command_s = tf.nn.softmax(da_model.end[1])
    command_r = tf.nn.softmax(da_model.end[2])

    command_l_one_hot = slim.one_hot_encoding(tf.argmax(command_l, 1), args.num_label)
    command_s_one_hot = slim.one_hot_encoding(tf.argmax(command_s, 1), args.num_label)
    command_r_one_hot = slim.one_hot_encoding(tf.argmax(command_r, 1), args.num_label)

    # Label one-hot
    labels = slim.one_hot_encoding(tf.argmax(source_label_batch, 1), args.num_label)

    # Per-class sample counts, weighted by the command column of each sample.
    command_l_labels = tf.reduce_sum(tf.expand_dims(source_command_batch[:, 0], 1) * labels, 0)
    command_s_labels = tf.reduce_sum(tf.expand_dims(source_command_batch[:, 1], 1) * labels, 0)
    command_r_labels = tf.reduce_sum(tf.expand_dims(source_command_batch[:, 2], 1) * labels, 0)

    # Per-class correct-prediction counts, weighted the same way.
    command_l_correct = tf.reduce_sum(
        tf.expand_dims(source_command_batch[:, 0], 1) * tf.multiply(labels, command_l_one_hot), 0)
    command_s_correct = tf.reduce_sum(
        tf.expand_dims(source_command_batch[:, 1], 1) * tf.multiply(labels, command_s_one_hot), 0)
    command_r_correct = tf.reduce_sum(
        tf.expand_dims(source_command_batch[:, 2], 1) * tf.multiply(labels, command_r_one_hot), 0)

    sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
    # Load checkpoint
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        tf.logging.info('Checkpoint loaded from %s' % ckpt_dir)
    else:
        tf.logging.warn('Checkpoint not loaded')
        return

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)

    try:
        # Accumulators are sized from ``args.num_label`` so they always match
        # the per-batch count vectors produced by the graph above.
        left_label_num = np.zeros(args.num_label)
        str_label_num = np.zeros(args.num_label)
        right_label_num = np.zeros(args.num_label)
        left_acc = np.zeros(args.num_label)
        str_acc = np.zeros(args.num_label)
        right_acc = np.zeros(args.num_label)

        for _ in range(args.num_eval):
            command, steer_label, np_im, steer_pred_l, steer_pred_s, steer_pred_r, \
                cmd_l_label_num, cmd_s_label_num, cmd_r_label_num, \
                cmd_l_correct, cmd_s_correct, cmd_r_correct = sess.run([
                    source_command_batch, source_label_batch, source_image_batch,
                    command_l, command_s, command_r,
                    command_l_labels, command_s_labels, command_r_labels,
                    command_l_correct, command_s_correct, command_r_correct])

            if args.print_info:
                print('COMMAND LABEL')
                print(command)
                print('STEER')
                print(steer_label)
                print('-' * 20)
                accuracy(steer_pred_l, steer_pred_s, steer_pred_r, args.batch_size, command, args.num_label)
                print(cmd_l_label_num)
                print(cmd_s_label_num)
                print(cmd_r_label_num)
                print('-' * 10)
                print(cmd_l_correct)
                print(cmd_s_correct)
                print(cmd_r_correct)

            left_label_num += cmd_l_label_num
            str_label_num += cmd_s_label_num
            right_label_num += cmd_r_label_num
            left_acc += cmd_l_correct
            str_acc += cmd_s_correct
            right_acc += cmd_r_correct

        # Turn correct counts into per-class accuracy. A class that never
        # occurred has a zero count and yields ``nan``.
        for i in range(args.num_label):
            left_acc[i] = left_acc[i] / left_label_num[i]
            str_acc[i] = str_acc[i] / str_label_num[i]
            right_acc[i] = right_acc[i] / right_label_num[i]

        print('LEFT COMMAND ACCURACY')
        print(left_acc)
        print('STRAIGHT COMMAND ACCURACY')
        print(str_acc)
        print('RIGHT COMMAND ACCURACY')
        print(right_acc)

    finally:
        # Always stop the input queue threads, even on error.
        coord.request_stop()
        coord.join(threads)