示例#1
0
        def write_grid(base_id, his_lens, gt_lens, pred_lens, his_landmarks,
                       gt_landmarks, pred_landmarks):
            """Render history, ground-truth and predicted poses as GIF files.

            The history and ground-truth clips are drawn from batch entry 0
            only. One prediction clip is written per batch entry, each one
            prefixed with the shared history clip, and all prediction clips
            are finally merged side by side into a single GIF.

            Args:
                base_id: number used to prefix every output file name.
                his_lens, gt_lens, pred_lens: per-sample frame counts.
                his_landmarks, gt_landmarks, pred_landmarks: per-sample,
                    per-frame landmarks handed to the `utils.visualize_*`
                    helpers.

            Returns:
                Always 0.
            """

            def _render_clip(num_frames, landmarks, sample_idx, border_color):
                # Frames start out black. Every helper result is written back
                # into the float32 buffer before the next helper reads it.
                clip = np.zeros((num_frames, img_size, img_size, 3),
                                dtype=np.float32)
                for t in xrange(num_frames):
                    clip[t] = utils.visualize_landmarks(
                        clip[t], landmarks[sample_idx][t])
                    clip[t] = utils.visualize_h36m_skeleton(
                        clip[t], landmarks[sample_idx][t])
                    clip[t] = utils.visualize_boundary(clip[t],
                                                       colormap=border_color)
                return clip

            tmp_dir = os.path.join(log_dir, 'tmp')
            utils.force_mkdir(tmp_dir)
            video_proc_lib = VideoProc()

            # History frames (green border) followed by the ground-truth
            # future frames (blue border), both taken from sample 0.
            his_video = _render_clip(his_lens[0], his_landmarks, 0, 'green')
            gt_video = _render_clip(gt_lens[0], gt_landmarks, 0, 'blue')
            video_proc_lib.save_img_seq_to_video(
                np.concatenate((his_video, gt_video), axis=0),
                log_dir,
                '%02d_gt.gif' % base_id,
                frame_rate=7.5,
                codec=None,
                override=True)

            # One predicted future clip (red border) per batch entry.
            raw_gif_list = []
            for i in xrange(batch_size):
                print(base_id * batch_size + i)
                pred_video = _render_clip(pred_lens[i], pred_landmarks, i,
                                          'red')
                gif_name = '%02d_pred%02d.gif' % (base_id, i)
                video_proc_lib.save_img_seq_to_video(
                    np.concatenate((his_video, pred_video), axis=0),
                    log_dir,
                    gif_name,
                    frame_rate=7.5,
                    codec=None,
                    override=True)
                raw_gif_list.append(gif_name)

            # Tile all prediction GIFs into one side-by-side comparison.
            video_proc_lib.merge_video_side_by_side(log_dir,
                                                    raw_gif_list,
                                                    '%02d_pred.gif' % base_id,
                                                    override=True)
            return 0
示例#2
0
def main(_):
    """Build the MTVAEPredModel training graph and hand it to slim's trainer.

    Everything is configured through the module-level `FLAGS`. Checkpoints and
    summaries are written to `<checkpoint_dir>/<model_name>/train`.
    """
    model_root = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name)
    train_dir = os.path.join(model_root, 'train')
    utils.force_mkdir(model_root)
    utils.force_mkdir(train_dir)

    graph = tf.Graph()
    with graph.as_default(), \
            tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        global_step = slim.get_or_create_global_step()

        # Model and its training input pipeline.
        model = MTVAEPredModel(FLAGS)
        train_data = model.get_inputs(FLAGS.inp_dir,
                                      FLAGS.dataset_name,
                                      'train',
                                      FLAGS.batch_size,
                                      is_training=True)
        inputs = model.preprocess(train_data, is_training=True)

        # Forward pass.
        model_fn = model.get_model_fn(is_training=True,
                                      use_prior=FLAGS.use_prior,
                                      reuse=False)
        outputs = model_fn(inputs)

        # The same scopes are trained and (optionally) initialised.
        train_scopes = ['seq_enc', 'latent_enc', 'latent_dec', 'fut_dec']
        init_scopes = train_scopes
        init_fn = model.get_init_fn(init_scopes) if FLAGS.init_model else None

        # Losses and the running-loss printing op.
        total_loss, loss_dict = model.get_loss(global_step, inputs, outputs)
        reg_loss = model.get_regularization_loss(outputs, train_scopes)
        print_op = model.print_running_loss(global_step, loss_dict)

        optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate,
                                           beta1=0.9,
                                           beta2=0.999)

        # Tie the printing op to the training step so it runs on every step.
        optimize_op = model.get_train_op_for_scope(total_loss + reg_loss,
                                                   optimizer, train_scopes)
        with tf.control_dependencies([print_op]):
            train_op = tf.identity(optimize_op)

        saver = tf.train.Saver(
            max_to_keep=np.minimum(5, FLAGS.worker_replicas + 1))

        slim.learning.train(train_op=train_op,
                            logdir=train_dir,
                            init_fn=init_fn,
                            master=FLAGS.master,
                            is_chief=(FLAGS.task == 0),
                            number_of_steps=FLAGS.max_number_of_steps,
                            saver=saver,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
def train(conf, train_shape_list, train_data_list, val_data_list,
          all_train_data_list):
    """Train the network while generating next-epoch data in the background.

    For each epoch the function (1) gathers the data produced by the `DataGen`
    jobs started during the previous epoch, (2) queues and starts the jobs for
    the next epoch, (3) trains on the current data, interleaving validation
    batches so that both loaders progress at the same relative rate, and
    (4) when `conf.sample_succ` is set, uses the network's own predictions to
    pick interactions to re-collect for the next epoch.

    Args:
        conf: configuration object; read for hyper-parameters, directories,
            logging switches and the target device.
        train_shape_list: not referenced inside this function.
        train_data_list: data directories used for the first (non-resumed)
            epoch; replaced by freshly generated data on later epochs.
        val_data_list: data directories loaded into the validation dataset.
        all_train_data_list: offline-generated data directories; entries whose
            trailing `_<id>` falls in the current epoch's id window are
            appended to the epoch's training list.
    """
    # create training and validation datasets and data loaders
    data_features = ['pcs', 'pc_pxids', 'pc_movables', 'gripper_img_target', 'gripper_direction_camera', 'gripper_forward_direction_camera', \
            'result', 'cur_dir', 'shape_id', 'trial_id', 'is_original']

    # load network model
    model_def = utils.get_model_module(conf.model_version)

    # create models
    network = model_def.Network(conf.feat_dim)
    utils.printout(conf.flog, '\n' + str(network) + '\n')

    # create optimizers
    network_opt = torch.optim.Adam(network.parameters(),
                                   lr=conf.lr,
                                   weight_decay=conf.weight_decay)

    # learning rate scheduler
    network_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        network_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create logs
    if not conf.no_console_log:
        header = '     Time    Epoch     Dataset    Iteration    Progress(%)       LR    TotalLoss'
    # The writers are passed to forward() unconditionally below, so they must
    # be bound even when TensorBoard logging is disabled.
    # NOTE(review): assumes forward() ignores tb_writer when log_tb is False.
    train_writer, val_writer = None, None
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.exp_dir, 'train'))
        val_writer = SummaryWriter(os.path.join(conf.exp_dir, 'val'))

    # send parameters to device
    network.to(conf.device)
    utils.optimizer_to_device(network_opt, conf.device)

    # load dataset
    train_dataset = SAPIENVisionDataset(
        [conf.primact_type], conf.category_types, data_features,
        conf.buffer_max_num,
        abs_thres=conf.abs_thres, rel_thres=conf.rel_thres,
        dp_thres=conf.dp_thres, img_size=conf.img_size,
        no_true_false_equal=conf.no_true_false_equal)

    val_dataset = SAPIENVisionDataset(
        [conf.primact_type], conf.category_types, data_features,
        conf.buffer_max_num,
        abs_thres=conf.abs_thres, rel_thres=conf.rel_thres,
        dp_thres=conf.dp_thres, img_size=conf.img_size,
        no_true_false_equal=conf.no_true_false_equal)
    val_dataset.load_data(val_data_list)
    utils.printout(conf.flog, str(val_dataset))

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=conf.batch_size, shuffle=False,
        pin_memory=True, num_workers=0, drop_last=True,
        collate_fn=utils.collate_feats, worker_init_fn=utils.worker_init_fn)
    val_num_batch = len(val_dataloader)

    # create a data generator
    datagen = DataGen(conf.num_processes_for_datagen, conf.flog)

    # sample succ
    if conf.sample_succ:
        sample_succ_list = []
        sample_succ_dirs = []

    # start training
    start_time = time.time()

    last_train_console_log_step, last_val_console_log_step = None, None

    # if resume
    start_epoch = 0
    if conf.resume:
        # figure out the latest epoch to resume; os.listdir returns entries in
        # arbitrary order, so take the maximum instead of the last one seen
        saved_epochs = [
            int(item.split('-')[0])
            for item in os.listdir(os.path.join(conf.exp_dir, 'ckpts'))
            if item.endswith('-train_dataset.pth')
        ]
        if saved_epochs:
            start_epoch = max(saved_epochs)

        # load states for network, optimizer, lr_scheduler, sample_succ_list
        data_to_restore = torch.load(
            os.path.join(conf.exp_dir, 'ckpts',
                         '%d-network.pth' % start_epoch))
        network.load_state_dict(data_to_restore)
        data_to_restore = torch.load(
            os.path.join(conf.exp_dir, 'ckpts',
                         '%d-optimizer.pth' % start_epoch))
        network_opt.load_state_dict(data_to_restore)
        data_to_restore = torch.load(
            os.path.join(conf.exp_dir, 'ckpts',
                         '%d-lr_scheduler.pth' % start_epoch))
        network_lr_scheduler.load_state_dict(data_to_restore)

        # rmdir and make a new dir for the current sample-succ directory
        old_sample_succ_dir = os.path.join(
            conf.data_dir, 'epoch-%04d_sample-succ' % (start_epoch - 1))
        utils.force_mkdir(old_sample_succ_dir)

    # train for every epoch
    for epoch in range(start_epoch, conf.epochs):
        ### collect data for the current epoch
        # The first epoch of a run uses the data passed in (or the restored
        # dataset); later epochs wait for the background jobs.
        if epoch > start_epoch:
            utils.printout(
                conf.flog,
                f'  [{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Waiting epoch-{epoch} data ]'
            )
            train_data_list = datagen.join_all()
            utils.printout(
                conf.flog,
                f'  [{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Gathered epoch-{epoch} data ]'
            )
            # Rewrite data_tuple_list.txt in every parent folder so that it
            # lists exactly the trial directories gathered for this epoch.
            cur_data_folders = []
            for item in train_data_list:
                item = '/'.join(item.split('/')[:-1])
                if item not in cur_data_folders:
                    cur_data_folders.append(item)
            for cur_data_folder in cur_data_folders:
                with open(os.path.join(cur_data_folder, 'data_tuple_list.txt'),
                          'w') as fout:
                    for item in train_data_list:
                        if cur_data_folder == '/'.join(item.split('/')[:-1]):
                            fout.write(item.split('/')[-1] + '\n')

            # load offline-generated sample-random data
            for item in all_train_data_list:
                valid_id_l = conf.num_interaction_data_offline + conf.num_interaction_data * (
                    epoch - 1)
                valid_id_r = conf.num_interaction_data_offline + conf.num_interaction_data * epoch
                if valid_id_l <= int(item.split('_')[-1]) < valid_id_r:
                    train_data_list.append(item)

        ### start generating data for the next epoch
        # sample succ
        if conf.sample_succ:
            if conf.resume and epoch == start_epoch:
                sample_succ_list = torch.load(
                    os.path.join(conf.exp_dir, 'ckpts',
                                 '%d-sample_succ_list.pth' % start_epoch))
            else:
                torch.save(
                    sample_succ_list,
                    os.path.join(conf.exp_dir, 'ckpts',
                                 '%d-sample_succ_list.pth' % epoch))
            for item in sample_succ_list:
                datagen.add_one_recollect_job(item[0], item[1], item[2],
                                              item[3], item[4], item[5],
                                              item[6])
            # Reset the per-epoch accumulators filled in by the batch loop.
            sample_succ_list = []
            sample_succ_dirs = []
            cur_sample_succ_dir = os.path.join(
                conf.data_dir, 'epoch-%04d_sample-succ' % epoch)
            utils.force_mkdir(cur_sample_succ_dir)

        # start all jobs
        datagen.start_all()
        utils.printout(
            conf.flog,
            f'  [ {strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Started generating epoch-{epoch+1} data ]'
        )

        ### load data for the current epoch
        if conf.resume and epoch == start_epoch:
            train_dataset = torch.load(
                os.path.join(conf.exp_dir, 'ckpts',
                             '%d-train_dataset.pth' % start_epoch))
        else:
            train_dataset.load_data(train_data_list)
        utils.printout(conf.flog, str(train_dataset))
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=conf.batch_size, shuffle=True,
            pin_memory=True, num_workers=0, drop_last=True,
            collate_fn=utils.collate_feats,
            worker_init_fn=utils.worker_init_fn)
        train_num_batch = len(train_dataloader)

        ### print log
        if not conf.no_console_log:
            utils.printout(conf.flog, f'training run {conf.exp_name}')
            utils.printout(conf.flog, header)

        train_batches = enumerate(train_dataloader, 0)
        val_batches = enumerate(val_dataloader, 0)

        train_fraction_done = 0.0
        val_fraction_done = 0.0
        val_batch_ind = -1

        ### train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                    train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # save checkpoint (once per epoch, before the first update)
            if train_batch_ind == 0:
                with torch.no_grad():
                    utils.printout(conf.flog, 'Saving checkpoint ...... ')
                    torch.save(
                        network.state_dict(),
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-network.pth' % epoch))
                    torch.save(
                        network_opt.state_dict(),
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-optimizer.pth' % epoch))
                    torch.save(
                        network_lr_scheduler.state_dict(),
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-lr_scheduler.pth' % epoch))
                    torch.save(
                        train_dataset,
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-train_dataset.pth' % epoch))
                    utils.printout(conf.flog, 'DONE')

            # set models to training mode
            network.train()

            # forward pass (including logging)
            total_loss, whole_feats, whole_pcs, whole_pxids, whole_movables = forward(
                batch=batch, data_features=data_features, network=network,
                conf=conf, is_val=False, step=train_step, epoch=epoch,
                batch_ind=train_batch_ind, num_batch=train_num_batch,
                start_time=start_time, log_console=log_console,
                log_tb=not conf.no_tb_log, tb_writer=train_writer,
                lr=network_opt.param_groups[0]['lr'])

            # optimize one step
            network_opt.zero_grad()
            total_loss.backward()
            network_opt.step()
            network_lr_scheduler.step()

            # sample succ
            if conf.sample_succ:
                network.eval()

                with torch.no_grad():
                    # sample a random EE orientation
                    random_up = torch.randn(conf.batch_size,
                                            3).float().to(conf.device)
                    random_forward = torch.randn(conf.batch_size,
                                                 3).float().to(conf.device)
                    # dim=1 is explicit: without it torch.cross picks the first
                    # size-3 dimension, which is the batch axis when
                    # batch_size == 3. The second cross product makes
                    # `random_forward` orthogonal to `random_up`.
                    random_left = torch.cross(random_up, random_forward, dim=1)
                    random_forward = torch.cross(random_left, random_up, dim=1)
                    random_dirs1 = F.normalize(random_up, dim=1).float()
                    random_dirs2 = F.normalize(random_forward, dim=1).float()

                    # test over the entire image, once per sign of the first
                    # direction
                    whole_pc_scores1 = network.inference_whole_pc(
                        whole_feats, random_dirs1, random_dirs2)  # B x N
                    whole_pc_scores2 = network.inference_whole_pc(
                        whole_feats, -random_dirs1, random_dirs2)  # B x N

                    # add to the sample_succ_list if wanted
                    ss_cur_dir = batch[data_features.index('cur_dir')]
                    ss_shape_id = batch[data_features.index('shape_id')]
                    ss_trial_id = batch[data_features.index('trial_id')]
                    ss_is_original = batch[data_features.index('is_original')]
                    for i in range(conf.batch_size):
                        valid_id_l = conf.num_interaction_data_offline + conf.num_interaction_data * (
                            epoch - 1)
                        valid_id_r = conf.num_interaction_data_offline + conf.num_interaction_data * epoch

                        # Only original (non sample-succ) trials from this
                        # epoch's id window, and each directory at most once.
                        if ('sample-succ' not in ss_cur_dir[i]) and (ss_is_original[i]) and (ss_cur_dir[i] not in sample_succ_dirs) \
                                and (valid_id_l <= int(ss_trial_id[i]) < valid_id_r):
                            sample_succ_dirs.append(ss_cur_dir[i])

                            # choose one from the two options
                            gt_movable = whole_movables[i].cpu().numpy()

                            # Scores are masked by gt_movable and zeroed below
                            # 0.5; the epsilon keeps the sums non-zero.
                            whole_pc_score1 = whole_pc_scores1[i].cpu().numpy(
                            ) * gt_movable
                            whole_pc_score1[whole_pc_score1 < 0.5] = 0
                            whole_pc_score_sum1 = np.sum(
                                whole_pc_score1) + 1e-12

                            whole_pc_score2 = whole_pc_scores2[i].cpu().numpy(
                            ) * gt_movable
                            whole_pc_score2[whole_pc_score2 < 0.5] = 0
                            whole_pc_score_sum2 = np.sum(
                                whole_pc_score2) + 1e-12

                            # Pick option 1 with probability proportional to
                            # its total score.
                            choose1or2_ratio = whole_pc_score_sum1 / (
                                whole_pc_score_sum1 + whole_pc_score_sum2)
                            random_dir1 = random_dirs1[i].cpu().numpy()
                            random_dir2 = random_dirs2[i].cpu().numpy()
                            if np.random.random() < choose1or2_ratio:
                                whole_pc_score = whole_pc_score1
                            else:
                                whole_pc_score = whole_pc_score2
                                random_dir1 = -random_dir1

                            # sample <X, Y> on each img, weighted by score
                            pp = whole_pc_score + 1e-12
                            ptid = np.random.choice(len(whole_pc_score),
                                                    1,
                                                    p=pp / pp.sum())
                            X = whole_pxids[i, ptid, 0].item()
                            Y = whole_pxids[i, ptid, 1].item()

                            # add job to the queue
                            str_cur_dir1 = ',' + ','.join(
                                ['%f' % elem for elem in random_dir1])
                            str_cur_dir2 = ',' + ','.join(
                                ['%f' % elem for elem in random_dir2])
                            sample_succ_list.append((conf.offline_data_dir, str_cur_dir1, str_cur_dir2, \
                                    ss_cur_dir[i].split('/')[-1], cur_sample_succ_dir, X, Y))

            # validate one batch
            # Consume validation batches until validation has caught up with
            # the fraction of the training epoch completed so far.
            while val_fraction_done <= train_fraction_done and val_batch_ind + 1 < val_num_batch:
                val_batch_ind, val_batch = next(val_batches)

                val_fraction_done = (val_batch_ind + 1) / val_num_batch
                val_step = (epoch + val_fraction_done) * train_num_batch - 1

                log_console = not conf.no_console_log and (last_val_console_log_step is None or \
                        val_step - last_val_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_val_console_log_step = val_step

                # set models to evaluation mode
                network.eval()

                with torch.no_grad():
                    # forward pass (including logging)
                    __ = forward(
                        batch=val_batch, data_features=data_features,
                        network=network, conf=conf, is_val=True,
                        step=val_step, epoch=epoch, batch_ind=val_batch_ind,
                        num_batch=val_num_batch, start_time=start_time,
                        log_console=log_console, log_tb=not conf.no_tb_log,
                        tb_writer=val_writer,
                        lr=network_opt.param_groups[0]['lr'])
示例#4
0
def main(_):
    """Run analogy-making inference over paired test sequences and save results.

    Test sequences are consumed in consecutive pairs: entry `2*i` supplies the
    A/B landmarks and entry `2*i + 1` supplies the C/D landmarks. For every
    pair the restored model predicts D, the A/B/C/predicted-D skeletons are
    rendered and saved, and the prediction is appended to a dictionary that is
    pickled to `<log_dir>/test_analogy.pkl` at the end.
    """
    params = add_attributes(FLAGS, MODEL_SPECS[FLAGS.model_version])
    model_dir = os.path.join(params.checkpoint_dir, params.model_name, 'train')
    img_dir = os.path.join(params.checkpoint_dir, 'analogy_comparison', 'imgs')
    log_dir = os.path.join(params.checkpoint_dir, 'analogy_comparison',
                           params.model_version)
    assert os.path.isdir(model_dir)
    utils.force_mkdir(os.path.join(params.checkpoint_dir,
                                   'analogy_comparison'))
    utils.force_mkdir(log_dir)
    utils.force_mkdir(img_dir)
    video_proc_lib = VideoProc()
    ##################
    ## Load dataset ##
    ##################
    filename_queue = get_dataset(FLAGS.inp_dir, FLAGS.dataset_name)
    dataset_size = len(filename_queue[0])
    # Floor division keeps the pair count an integer under true division
    # (Python 3 or `from __future__ import division`); a float here would
    # break the xrange() call below.
    num_pairs = dataset_size // 2
    assert params.min_input_length == params.max_input_length
    init_length = params.min_input_length
    sample_length = params.max_length + init_length

    np_data = dict()
    np_data['landmarks'] = np.zeros(
        (dataset_size, sample_length, params.keypoint_dim, 2),
        dtype=np.float32)
    np_data['actual_frame_id'] = []
    np_data['vid_id'] = []
    for i in xrange(dataset_size):
        mid_frame = int(filename_queue[1][i])
        # `init_length` frames before the mid frame plus `max_length` frames
        # starting at it.
        keyframes = np.arange(mid_frame - init_length,
                              mid_frame + params.max_length)
        _, landmarks = input_generator.load_pts_seq(params.dataset_name,
                                                    filename_queue[0][i],
                                                    sample_length, keyframes)
        landmarks = np.reshape(landmarks,
                               (sample_length, params.keypoint_dim, 2))
        np_data['landmarks'][i] = landmarks
        # TODO(xcyan): hacky implementation.
        np_data['vid_id'].append(filename_queue[2][i])
        np_data['actual_frame_id'].append(filename_queue[3][i] + mid_frame)
    ####################
    ## save html page ##
    ####################
    save_html_page(img_dir, num_pairs)
    #########################
    ## Prepare output dict ##
    #########################
    genseq = dict()
    genseq['AB_id'] = []
    genseq['CD_id'] = []
    genseq['A_start'] = []
    genseq['B_start'] = []
    genseq['C_start'] = []
    genseq['pred_start'] = []
    genseq['pred_keypoints'] = []
    ####################
    ## Build tf.graph ##
    ####################
    g = tf.Graph()
    with g.as_default():
        keypointClass = MODEL_TO_CLASS[params.model_version]
        model = keypointClass(params)
        scopes = MODEL_TO_SCOPE[params.model_version]

        eval_data = model.get_inputs_from_placeholder(dataset_size,
                                                      params.batch_size)
        inputs = model.preprocess(eval_data,
                                  is_training=False,
                                  load_image=False)
        assert (not params.use_prior)
        model_fn = model.get_model_fn(is_training=False, reuse=False)
        outputs = model_fn(inputs)
        variables_to_restore = slim.get_variables_to_restore(scopes)
        #######################
        ## Restore variables ##
        #######################
        checkpoint_path = tf.train.latest_checkpoint(model_dir)
        restorer = tf.train.Saver(variables_to_restore)
        ####################
        ## launch session ##
        ####################
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            restorer.restore(sess, checkpoint_path)

            for i in xrange(num_pairs):
                print(i)
                AB_idx, CD_idx = i * 2, i * 2 + 1
                # Feed one sequence per placeholder (slices of length 1).
                A_lms, B_lms, C_lms, predD_lms = sess.run(
                    [
                        inputs['A_landmarks'], inputs['B_landmarks'],
                        inputs['C_landmarks'], outputs['D_landmarks']
                    ],
                    feed_dict={
                        eval_data['AB_landmarks']:
                        np_data['landmarks'][AB_idx:AB_idx + 1],
                        eval_data['CD_landmarks']:
                        np_data['landmarks'][CD_idx:CD_idx + 1]
                    })
                # NOTE(review): presumably rescales predictions from [0, 1]
                # to [-1, 1] for storage -- confirm against the consumer.
                batch_output = np.copy(predD_lms) * 2 - 1.0
                A_imgs = utils.visualize_h36m_skeleton_batch(
                    np.copy(A_lms[0]), params.img_size)
                B_imgs = utils.visualize_h36m_skeleton_batch(
                    np.copy(B_lms[0]), params.img_size)
                C_imgs = utils.visualize_h36m_skeleton_batch(
                    np.copy(C_lms[0]), params.img_size)
                predD_imgs = utils.visualize_h36m_skeleton_batch(
                    np.copy(predD_lms[0]), params.img_size)
                ################
                ## Save video ##
                ################
                _, out_dict = run_visualization(video_proc_lib, log_dir, i,
                                                params.model_version, A_imgs,
                                                B_imgs, C_imgs, predD_imgs)
                save_images(img_dir, i, params.model_version, out_dict)
                ##########
                ## save ##
                ##########
                genseq['AB_id'].append(np_data['vid_id'][AB_idx])
                genseq['CD_id'].append(np_data['vid_id'][CD_idx])
                genseq['A_start'].append(np_data['actual_frame_id'][AB_idx] -
                                         init_length)
                genseq['B_start'].append(np_data['actual_frame_id'][AB_idx])
                genseq['C_start'].append(np_data['actual_frame_id'][CD_idx] -
                                         init_length)
                genseq['pred_start'].append(np_data['actual_frame_id'][CD_idx])
                genseq['pred_keypoints'].append(batch_output)
                # Print the min/max of each coordinate channel of the
                # stored prediction.
                print('%g %g' % (np.amin(batch_output[:, :, :, 0]),
                                 np.amax(batch_output[:, :, :, 0])))
                print('%g %g' % (np.amin(batch_output[:, :, :, 1]),
                                 np.amax(batch_output[:, :, :, 1])))

    utils.save_python_objects(genseq, os.path.join(log_dir,
                                                   'test_analogy.pkl'))
示例#5
0
# Command-line flags: input file names (relative to the workspace directory)
# and model/data hyper-parameters.
flags.DEFINE_string('embedding_file', 'new_glove.txt', '')
flags.DEFINE_string('titles_file', 'AbsSumm_title_60k.pkl', '')
flags.DEFINE_string('paras_file', 'AbsSumm_text_60k.pkl', '')
flags.DEFINE_float('dropout', 0.2, '')
flags.DEFINE_integer('test_size', 5000, '')
flags.DEFINE_integer('train_size', 60000, '')
flags.DEFINE_integer('maximum_iterations', 60, '')

FLAGS = flags.FLAGS
data_root_dir = './workspace'


def workspace_path(file_path):
    """Return `file_path` joined onto the workspace root directory."""
    return os.path.join(data_root_dir, file_path)


# Read the raw file names from the flags first, then create the checkpoint
# directory, and only then resolve the names against the workspace root.
paras_file = FLAGS.paras_file
titles_file = FLAGS.titles_file
# embedding_file = 'glove.6B.100d.txt'
embedding_file = FLAGS.embedding_file
ckpt_dir = './checkpoints'
utils.force_mkdir(ckpt_dir)

paras_file = workspace_path(paras_file)
titles_file = workspace_path(titles_file)
embedding_file = workspace_path(embedding_file)

input_data, inputs_for_tf, input_placeholders = input_generator.get_inputs(
    paras_file, titles_file, embedding_file, FLAGS)

###########
## Model ##
###########
outputs = RNNModel(inputs_for_tf,
                   FLAGS,
                   is_training=IS_TRAINING,
                   multirnn=False)
示例#6
0
def main(_):
    """Samples future keypoints for every test sequence and saves them.

    Builds the model selected by `FLAGS.model_version`, restores its latest
    checkpoint from `<checkpoint_dir>/<model_name>/train`, and runs it on each
    sequence returned by `get_dataset`. For every sequence the ground-truth and
    sampled skeleton sequences are visualized under
    `<checkpoint_dir>/gensample_comparison`, and the sampled keypoints are
    collected (with their video id and frame offsets) and written to
    `test.pkl` in the per-model log directory.

    Args:
        _: Unused positional argument.
    """
    params = add_attributes(FLAGS, MODEL_SPECS[FLAGS.model_version])
    model_dir = os.path.join(FLAGS.checkpoint_dir, params.model_name, 'train')
    comparison_dir = os.path.join(params.checkpoint_dir,
                                  'gensample_comparison')
    img_dir = os.path.join(comparison_dir, 'imgs')
    log_dir = os.path.join(comparison_dir, params.model_version)
    # PredLSTM is run with a single sample per sequence.
    if params.model_version == 'PredLSTM':
        params.batch_size = 1
    assert os.path.isdir(model_dir)
    for out_dir in (comparison_dir, log_dir, img_dir):
        utils.force_mkdir(out_dir)
    video_proc_utils = VideoProc()
    ##################
    ## load dataset ##
    ##################
    filename_queue = get_dataset(params.inp_dir, params.dataset_name)
    dataset_size = len(filename_queue[0])
    # The observed (history) length must be fixed.
    assert params.min_input_length == params.max_input_length
    init_length = params.min_input_length
    sample_length = init_length + params.max_length
    landmark_shape = (sample_length, params.keypoint_dim, 2)
    np_data = {
        'vid_id': [],
        'actual_frame_id': [],
        'landmarks': np.zeros((dataset_size,) + landmark_shape,
                              dtype=np.float32),
    }
    for sample_idx in xrange(dataset_size):
        mid_frame = int(filename_queue[1][sample_idx])
        # Frames [mid - init_length, mid + max_length): history then future.
        keyframes = np.arange(mid_frame - init_length,
                              mid_frame + params.max_length)
        _, raw_landmarks = input_generator.load_pts_seq(
            params.dataset_name, filename_queue[0][sample_idx], sample_length,
            keyframes)
        np_data['landmarks'][sample_idx] = np.reshape(raw_landmarks,
                                                      landmark_shape)
        # TODO(xcyan): hacky implementation.
        testcase = H36M_testcases[sample_idx]
        base_frame_id = int(testcase[1][-3:]) * SEQ_LEN
        np_data['vid_id'].append(int(testcase[0]))
        np_data['actual_frame_id'].append(base_frame_id + mid_frame)
    #########################
    ## Prepare output dict ##
    #########################
    genseq = {
        'video_id': [],
        'ob_start': [],
        'pred_start': [],
        'pred_keypoints': [],
    }
    #################
    ## Build graph ##
    #################
    graph = tf.Graph()
    with graph.as_default():
        model_cls = MODEL_TO_CLASS[params.model_version]
        model = model_cls(params)
        scopes = MODEL_TO_SCOPE[params.model_version]

        eval_data = model.get_inputs_from_placeholder(dataset_size,
                                                      params.batch_size)
        inputs = model.preprocess(eval_data,
                                  is_training=False,
                                  load_image=False)
        ##############
        ## model_fn ##
        ##############
        # Every model except PredLSTM also receives the `use_prior` option.
        sample_fn_kwargs = {'is_training': False, 'reuse': False}
        if params.model_version != 'PredLSTM':
            sample_fn_kwargs['use_prior'] = params.use_prior
        model_fn = model.get_sample_fn(**sample_fn_kwargs)
        outputs = model_fn(inputs)
        #######################
        ## restore variables ##
        #######################
        restorer = tf.train.Saver(slim.get_variables_to_restore(scopes))
        checkpoint_path = tf.train.latest_checkpoint(model_dir)
        ####################
        ## launch session ##
        ####################
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            restorer.restore(sess, checkpoint_path)
            for sample_idx in xrange(dataset_size):
                print(sample_idx)
                raw_gif_list = []
                # Repeat the single sequence across the batch dimension so
                # that one session run yields `batch_size` outputs for it.
                tiled_landmarks = np.tile(
                    np_data['landmarks'][sample_idx:sample_idx + 1],
                    (params.batch_size, 1, 1, 1))
                fetches = [
                    inputs['his_landmarks'],
                    inputs['fut_landmarks'],
                    outputs['fut_landmarks'],
                ]
                his_lms, fut_lms, pred_lms = sess.run(
                    fetches,
                    feed_dict={eval_data['landmarks']: tiled_landmarks})

                # Stored keypoints are rescaled as x * 2 - 1 (presumably from
                # [0, 1] to [-1, 1]; confirm against the model's output range).
                batch_output = np.copy(pred_lms) * 2 - 1.0
                for run_id in xrange(params.batch_size):
                    his_imgs, fut_imgs, pred_imgs = [
                        utils.visualize_h36m_skeleton_batch(
                            np.copy(lms[run_id]), params.img_size)
                        for lms in (his_lms, fut_lms, pred_lms)
                    ]
                    ################
                    ## save video ##
                    ################
                    # The ground truth is written once per sequence.
                    if run_id == 0:
                        _, gt_dict = run_visualization(
                            video_proc_utils, log_dir, sample_idx, 'gt',
                            his_imgs, fut_imgs)
                        save_images(img_dir, sample_idx, 'gt', gt_dict)

                    run_tag = params.model_version + '_%02d' % run_id
                    vid_file, pred_dict = run_visualization(
                        video_proc_utils, log_dir, sample_idx, run_tag,
                        his_imgs, pred_imgs)
                    save_images(img_dir, sample_idx, run_tag, pred_dict)
                    raw_gif_list.append(vid_file)
                #################
                ## Merge video ##
                #################
                if params.batch_size > 1:
                    merged_name = '%02d_merged_%s.gif' % (
                        sample_idx, params.model_version)
                    video_proc_utils.merge_video_side_by_side(
                        log_dir, raw_gif_list, merged_name, override=True)
                ##########
                ## Save ##
                ##########
                frame_id = np_data['actual_frame_id'][sample_idx]
                genseq['video_id'].append(np_data['vid_id'][sample_idx])
                genseq['ob_start'].append(frame_id - init_length)
                genseq['pred_start'].append(frame_id)
                genseq['pred_keypoints'].append(batch_output)
                # Print min / max of each of the two coordinate channels.
                for coord in (0, 1):
                    print('%g %g' % (np.amin(batch_output[:, :, :, coord]),
                                     np.amax(batch_output[:, :, :, coord])))

    utils.save_python_objects(genseq, os.path.join(log_dir, 'test.pkl'))