예제 #1
0
def mode_8(sess, graph, save_path):
    """ Find real plays with an unusually high or low heuristic
    (open-shot) penalty and render them as videos.

    The critic's per-frame heuristic penalty is evaluated over the training
    split in full mini-batches of ``FLAGS.batch_size`` (plays after the last
    full batch are not scored).  Each play's penalties are averaged over
    axis 1, and a video is saved for every play whose mean is above
    ``mean + 2 * std`` ('bad_<i>_<v>.mp4') or below 0.0025
    ('good_<i>_<v>.mp4').

    Inputs
    ------
    sess : tf.Session holding the restored model
    graph : graph containing the model's tensors
    save_path : directory the videos are written into (created if missing)

    Raises
    ------
    ValueError
        if the training split has fewer plays than one batch.
    """
    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    data_factory = DataFactory(real_data)
    train_data, _ = data_factory.fetch_data()
    # placeholder tensor
    real_data_t = graph.get_tensor_by_name('real_data:0')
    matched_cond_t = graph.get_tensor_by_name('matched_cond:0')
    # result tensor: heuristic penalty per frame
    heuristic_penalty_pframe = graph.get_tensor_by_name(
        'Critic/C_inference/heuristic_penalty/Min:0')

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    n_plays = train_data['A'].shape[0]
    n_batches = n_plays // FLAGS.batch_size
    if n_batches == 0:
        # np.concatenate([]) below would otherwise fail with a cryptic error
        raise ValueError(
            'need at least FLAGS.batch_size ({}) training plays, got {}'.format(
                FLAGS.batch_size, n_plays))

    real_hp_pframe_all = []
    for batch_id in range(n_batches):
        start = batch_id * FLAGS.batch_size
        end = start + FLAGS.batch_size
        feed_dict = {
            real_data_t: train_data['B'][start:end],
            matched_cond_t: train_data['A'][start:end],
        }
        real_hp_pframe_all.append(
            sess.run(heuristic_penalty_pframe, feed_dict=feed_dict))
    real_hp_pframe_all = np.concatenate(real_hp_pframe_all, axis=0)
    print(real_hp_pframe_all.shape)
    # one mean penalty per play (axis 1 presumably the frame axis)
    real_hp_pdata = np.mean(real_hp_pframe_all, axis=1)
    mean_ = np.mean(real_hp_pdata)
    std_ = np.std(real_hp_pdata)
    print(mean_)
    print(std_)

    concat_AB = np.concatenate([train_data['A'], train_data['B']], axis=-1)
    recoverd = data_factory.recover_data(concat_AB)
    bad_threshold = mean_ + 2 * std_
    for i, v in enumerate(real_hp_pdata):
        # the two checks are independent: a play may be saved as both
        if v > bad_threshold:
            print('bad', i, v)
            game_visualizer.plot_data(
                recoverd[i],
                recoverd.shape[1],
                file_path=os.path.join(
                    save_path, 'bad_' + str(i) + '_' + str(v) + '.mp4'),
                if_save=True)
        if v < 0.0025:
            print('good', i, v)
            game_visualizer.plot_data(
                recoverd[i],
                recoverd.shape[1],
                file_path=os.path.join(
                    save_path, 'good_' + str(i) + '_' + str(v) + '.mp4'),
                if_save=True)

    print('!!Completely Saved!!')
def generate_defensive_strategy(sess, graph, offense_input,
                                data_path='../../data/FEATURES-4.npy',
                                n_samples=100, latent_dims=100):
    """ Given one offensive input, generate ``n_samples`` defensive
    strategies and return only the one with the highest critic score.

    Inputs
    ------
    offense_input : float, shape=[length, 13]
        length may vary; [13] -> [ball's xyz * 1, offensive players' xy * 5]
    data_path : str
        dataset loaded to build the DataFactory used for normalization and
        recovery (default keeps the original relative path)
    n_samples : int
        number of candidate strategies sampled in one batch (default 100)
    latent_dims : int
        width of each latent vector; must match the generator's
        'latent_input' placeholder (default 100)

    Returns
    -------
    defense_result : float, shape=[length, 10]
        length may vary; [10] -> [defensive players' xy * 5]

    Raises
    ------
    ValueError
        if ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError('n_samples must be >= 1, got {}'.format(n_samples))

    # placeholder tensor
    latent_input_t = graph.get_tensor_by_name('Generator/latent_input:0')
    team_a_t = graph.get_tensor_by_name('Generator/team_a:0')
    G_samples_t = graph.get_tensor_by_name('Critic/G_samples:0')
    matched_cond_t = graph.get_tensor_by_name('Critic/matched_cond:0')
    # result tensor
    result_t = graph.get_tensor_by_name(
        'Generator/G_inference/conv_result/conv1d/Maximum:0')
    critic_scores_t = graph.get_tensor_by_name(
        'Critic/C_inference_1/conv_output/Reshape:0')

    real_data = np.load(data_path)
    # DataFactory
    data_factory = DataFactory(real_data)
    conditions = data_factory.normalize_offense(
        np.expand_dims(offense_input, axis=0))
    # the same condition repeated once per latent -> all candidates in one run
    conditions_duplicated = np.concatenate([conditions] * n_samples, axis=0)
    # generate result
    latents = np.random.normal(0., 1., size=[n_samples, latent_dims])
    result = sess.run(result_t, feed_dict={
        latent_input_t: latents,
        team_a_t: conditions_duplicated
    })
    # score every candidate with the critic (em distance)
    critic_scores = sess.run(critic_scores_t, feed_dict={
        G_samples_t: result,
        matched_cond_t: conditions_duplicated
    })
    recoverd_A_fake_B = data_factory.recover_B(result)

    # NOTE(review): assumes one critic score per candidate; np.argmax
    # flattens a multi-dimensional score array.
    return recoverd_A_fake_B[np.argmax(critic_scores)]
예제 #3
0
def test():
    """
    test only: render the first ``opt.amount`` plays of the dataset.

    Loads ``opt.data_path``, passes it through ``DataFactory.fetch_ori_data``
    and ``recover_data``, and calls ``plot_data`` on one play at a time,
    writing to ``opt.save_path + 'play_<i>.mp4'`` when ``opt.save`` is set.
    """
    train_data = np.load(opt.data_path)
    data_factory = DataFactory(train_data)
    train_data = data_factory.fetch_ori_data()
    train_data = data_factory.recover_data(train_data)
    # Clamp so that asking for more plays than exist does not hand empty
    # slices to plot_data.
    amount = min(opt.amount, len(train_data))
    for i in range(amount):
        # Plot the recovered data; the original referenced an undefined
        # name (results_data) here.
        plot_data(train_data[i:i + 1],
                  length=100,
                  file_path=opt.save_path + 'play_' + str(i) + '.mp4',
                  if_save=opt.save)
def mode_7(sess, graph, save_path):
    """ Draw the generator's condition feature map ('conds_linear') as
    heatmaps.

    The first (up to) six plays of 'FEATURES-6.npy' are zero-padded to one
    full batch of ``FLAGS.batch_size`` plays, normalized, and their ball +
    team A columns are fed as the generator condition.  For each target
    play, the output of 'Generator/G_inference/conds_linear' is saved as a
    plotly heatmap 'G_conds_linear<i>.html' in ``save_path``.

    Raises
    ------
    ValueError
        if ``FLAGS.batch_size`` is smaller than the number of target plays.
    """
    # normalize
    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    data_factory = DataFactory(real_data)
    target_data = np.load('FEATURES-6.npy')[:6]
    n_targets, seq_len = target_data.shape[0], target_data.shape[1]
    team_AB = np.concatenate(
        [
            # ball
            target_data[:, :, 0, :3].reshape([n_targets, seq_len, 1 * 3]),
            # team A players
            target_data[:, :, 1:6, :2].reshape([n_targets, seq_len, 5 * 2]),
            # team B players
            target_data[:, :, 6:11, :2].reshape([n_targets, seq_len, 5 * 2])
        ], axis=-1
    )
    # Zero-pad to one full batch.  The padding is sized from FLAGS.batch_size
    # and the data (instead of a hard-coded 128 x 100 x 23) so team_A always
    # has as many rows as the latents fed below.
    n_padding = FLAGS.batch_size - n_targets
    if n_padding < 0:
        raise ValueError(
            'FLAGS.batch_size ({}) must be >= number of target plays ({})'
            .format(FLAGS.batch_size, n_targets))
    dummy_AB = np.zeros(shape=[n_padding, seq_len, team_AB.shape[-1]])
    team_AB = np.concatenate([team_AB, dummy_AB], axis=0)
    team_AB = data_factory.normalize(team_AB)
    # first 13 columns: ball xyz + team A xy -> generator condition
    team_A = team_AB[:, :, :13]
    # placeholder tensor
    latent_input_t = graph.get_tensor_by_name('latent_input:0')
    team_a_t = graph.get_tensor_by_name('team_a:0')
    # result tensor
    conds_linear_t = graph.get_tensor_by_name(
        'Generator/G_inference/conds_linear/BiasAdd:0')

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # result collector
    latents = np.concatenate([z_samples(1)
                              for _ in range(FLAGS.batch_size)], axis=0)
    feed_dict = {
        latent_input_t: latents,
        team_a_t: team_A
    }
    conds_linear = sess.run(conds_linear_t, feed_dict=feed_dict)
    # only the real (non-padding) plays are drawn
    for i in range(n_targets):
        trace = go.Heatmap(z=conds_linear[i])
        plotly.offline.plot([trace], filename=os.path.join(
            save_path, 'G_conds_linear' + str(i) + '.html'))

    print('!!Completely Saved!!')
예제 #5
0
def main(_):
    """ Load the dataset, split it into train/valid and start training
    inside the default TensorFlow graph. """
    with tf.get_default_graph().as_default() as graph:
        # keep only the first FLAGS.seq_length frames of each play
        loaded = np.load(FLAGS.data_path)
        real_data = loaded[:, :FLAGS.seq_length, :, :]
        print('real_data.shape', real_data.shape)

        # normalize and split
        data_factory = DataFactory(real_data)
        train_data, valid_data = data_factory.fetch_data()
        for split in (train_data, valid_data):
            print(split['A'].shape)

        # config setting
        config = TrainingConfig()
        config.show()

        # train
        training(train_data, valid_data, data_factory, config, graph)
예제 #6
0
    def __init__(self, config, loss_for_G, graph):
        """ Build the generator graph.

        Inputs
        ------
        config :
            * batch_size : mini batch size
            * log_dir : path prefix for training summaries ('G' is appended)
            * learning_rate : adam's learning rate
            * seq_length : length of sequence during training
            * latent_dims : latent dimensions (width of 'latent_input')
            * penalty_lambda : gradient penalty's weight, ref from paper 'improved-wgan'
            * latent_penalty_lambda : copied onto the instance
            * n_resblock : number of resblock in network body
            * if_feed_extra_info : basket position
            * residual_alpha : residual block = F(x) * residual_alpha + x
            * leaky_relu_alpha : tf.maximum(x, leaky_relu_alpha * x)
            * n_filters : number of filters in all ConV
        loss_for_G : function
            from the Critic network; given the generated fake result,
            returns its scores
        graph :
            tensorflow default graph; used for the global step and written
            to the summary writer
        """
        self.data_factory = DataFactory()
        # hyper-parameters: copy each config field onto the instance
        hyper_params = ('batch_size', 'log_dir', 'learning_rate', 'seq_length',
                        'latent_dims', 'penalty_lambda', 'latent_penalty_lambda',
                        'n_resblock', 'if_feed_extra_info', 'residual_alpha',
                        'leaky_relu_alpha', 'n_filters')
        for attr_name in hyper_params:
            setattr(self, attr_name, getattr(config, attr_name))

        # steps
        self.__global_steps = tf.train.get_or_create_global_step(graph=graph)
        with tf.name_scope('Generator'):
            # generator-local step counter, excluded from training
            self.__steps = tf.get_variable(
                'G_steps',
                shape=[],
                dtype=tf.int32,
                initializer=tf.zeros_initializer(dtype=tf.int32),
                trainable=False)
            # IO
            self.loss_for_G = loss_for_G
            self.__z = tf.placeholder(
                dtype=tf.float32,
                shape=[None, self.latent_dims],
                name='latent_input')
            self.__cond = tf.placeholder(
                dtype=tf.float32, shape=[None, None, 13], name='team_a')
            self.__real_data = tf.placeholder(
                dtype=tf.float32, shape=[None, None, 10], name='real_data')
            # adversarial learning : wgan
            self.__build_model()

            # summary ops are merged from collections filled by the model
            self.__summary_op = tf.summary.merge(tf.get_collection('G'))
            self.__summary_histogram_op = tf.summary.merge(
                tf.get_collection('G_histogram'))
            self.__summary_weight_op = tf.summary.merge(
                tf.get_collection('G_weight'))
            self.summary_writer = tf.summary.FileWriter(
                self.log_dir + 'G', graph=graph)
def mode_6(sess, graph, save_path):
    """ Draw generator results for plays of different lengths.

    Loads target plays from 'FEATURES-7.npy'. The generator produces team B
    for each play, conditioned on its ball + team A columns, and one video
    per play is saved as '<save_path>/<i>.mp4'.

    Inputs
    ------
    sess : tf.Session holding the restored model
    graph : graph containing the model's tensors
    save_path : directory for the videos (created if missing)
    """
    # normalize (DataFactory is built from the data NOT truncated to
    # FLAGS.seq_length, since the target plays may have any length)
    real_data = np.load(FLAGS.data_path)
    print('real_data.shape', real_data.shape)
    data_factory = DataFactory(real_data)
    target_data = np.load('FEATURES-7.npy')
    n_plays, seq_len = target_data.shape[0], target_data.shape[1]
    team_AB = np.concatenate(
        [
            # ball
            target_data[:, :, 0, :3].reshape([n_plays, seq_len, 1 * 3]),
            # team A players
            target_data[:, :, 1:6, :2].reshape([n_plays, seq_len, 5 * 2]),
            # team B players
            target_data[:, :, 6:11, :2].reshape([n_plays, seq_len, 5 * 2])
        ], axis=-1
    )
    team_AB = data_factory.normalize(team_AB)
    # first 13 columns: ball xyz + team A xy -> generator condition
    team_A = team_AB[:, :, :13]
    # placeholder tensor
    latent_input_t = graph.get_tensor_by_name('latent_input:0')
    team_a_t = graph.get_tensor_by_name('team_a:0')
    # result tensor
    result_t = graph.get_tensor_by_name(
        'Generator/G_inference/conv_result/conv1d/Maximum:0')
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # result collector
    latents = z_samples(n_plays)
    feed_dict = {
        latent_input_t: latents,
        team_a_t: team_A
    }
    result_fake_B = sess.run(result_t, feed_dict=feed_dict)
    results_A_fake_B = np.concatenate([team_A, result_fake_B], axis=-1)
    results_A_fake_B = data_factory.recover_data(results_A_fake_B)
    for i in range(results_A_fake_B.shape[0]):
        # join with a separator so files land inside save_path even when it
        # has no trailing slash
        game_visualizer.plot_data(
            results_A_fake_B[i], seq_len,
            file_path=os.path.join(save_path, str(i) + '.mp4'),
            if_save=True)

    print('!!Completely Saved!!')
예제 #8
0
    def __init__(self, config):
        """ Build the network and losses, start a session with all variables
        initialised, and set up a checkpoint saver and summary writers.

        Inputs
        ------
        config :
            * lr_, batch_size, seq_length, latent_dims, n_filters,
              features_, features_d, keep_prob, n_resblock : copied onto
              the instance
            * folder_path : root under which the 'Log/G', 'Log/D' and
              'Log/D_valid' summary directories are written
        """
        self.global_step = tf.train.get_or_create_global_step()
        for attr_name in ('lr_', 'batch_size', 'seq_length', 'latent_dims',
                          'n_filters', 'features_', 'features_d',
                          'keep_prob', 'n_resblock'):
            setattr(self, attr_name, getattr(config, attr_name))

        self.data_factory = DataFactory()

        def _sequence_placeholder(width, name):
            # float32 input of shape [batch, time, width]
            return tf.placeholder(tf.float32,
                                  shape=[None, None, width],
                                  name=name)

        # real offence / real defence
        self.input_ = _sequence_placeholder(self.features_, 'Real')
        self.input_d = _sequence_placeholder(self.features_d, 'Real_defence')
        self.ground_feature = _sequence_placeholder(6, 'Real_feat')
        # condition data
        self.seq_input = _sequence_placeholder(self.features_, 'Cond_input')
        self.seq_feature = _sequence_placeholder(6, 'Seq_feat')
        # latent codes
        self.z_sample = tf.placeholder(tf.float32,
                                       shape=[None, self.latent_dims],
                                       name='Latent')

        self.network_()
        self.loss_()

        # initialise everything created so far in a fresh session;
        # max_to_keep=0 keeps every checkpoint
        init_ = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(init_)
        self.saver = tf.train.Saver(max_to_keep=0)

        # summary collection
        self.G_summaries = tf.summary.merge(tf.get_collection('G'))
        self.D_summaries = tf.summary.merge(tf.get_collection('D'))
        # summary writers (only the G writer records the graph)
        folder = config.folder_path
        self.G_summary_writer = tf.summary.FileWriter(
            os.path.join(folder, 'Log/G'), graph=tf.get_default_graph())
        self.D_summary_writer = tf.summary.FileWriter(
            os.path.join(folder, 'Log/D'))
        self.D_valid_summary_writer = tf.summary.FileWriter(
            os.path.join(folder, 'Log/D_valid'))
def main(_):
    """ Load and split the dataset, then start training.

    A dedicated graph for the baseline model is created only when
    ``FLAGS.baseline_checkpoint`` is set; otherwise ``None`` is passed on.
    """
    full_data = np.load(FLAGS.data_path)
    real_data = full_data[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    # normalize and split
    data_factory = DataFactory(real_data)
    train_data, valid_data = data_factory.fetch_data()
    for split in (train_data, valid_data):
        print(split['A'].shape)
    # config setting
    config = TrainingConfig()

    if FLAGS.baseline_checkpoint is None:
        baseline_graph = None
    else:
        baseline_graph = tf.Graph()
    default_graph = tf.Graph()

    training(train_data, valid_data, data_factory, config, default_graph,
             baseline_graph)
예제 #10
0
    def __init__(self, config, graph):
        """ Build the critic graph.

        Inputs
        ------
        config :
            * batch_size : mini batch size (fixed first dim of every placeholder)
            * log_dir : path prefix for summaries ('C' / 'C_valid' appended)
            * learning_rate : adam's learning rate
            * hidden_size : number of hidden units in LSTM
            * rnn_layers : number of stacked LSTM
            * seq_length : length of LSTM (fixed second dim of every placeholder)
            * latent_dims : dimensions of latent feature
            * penalty_lambda : gradient penalty's weight, ref from paper of 'improved-wgan'
            * if_log_histogram, n_resblock : copied onto the instance
        graph :
            tensorflow default graph; used for the global step
        """
        self.data_factory = DataFactory()
        # hyper-parameters
        for attr_name in ('batch_size', 'log_dir', 'learning_rate',
                          'hidden_size', 'rnn_layers', 'seq_length',
                          'latent_dims', 'penalty_lambda',
                          'if_log_histogram', 'n_resblock'):
            setattr(self, attr_name, getattr(config, attr_name))
        # steps
        self.__global_steps = tf.train.get_or_create_global_step(graph=graph)
        self.__steps = 0
        # data: every input has shape [batch_size, seq_length, features]
        batch_and_time = [self.batch_size, self.seq_length]
        self.__G_samples = tf.placeholder(dtype=tf.float32,
                                          shape=batch_and_time + [10],
                                          name='G_samples')
        self.__X = tf.placeholder(dtype=tf.float32,
                                  shape=batch_and_time + [10],
                                  name='real_data')
        # TODO: add a placeholder for mismatched conditional constraints
        self.__matched_cond = tf.placeholder(dtype=tf.float32,
                                             shape=batch_and_time + [13],
                                             name='matched_cond')
        # adversarial learning : wgan
        self.__build_wgan()

        # summary
        self.__summary_op = tf.summary.merge(tf.get_collection('C'))
        self.__summary_valid_op = tf.summary.merge(
            tf.get_collection('C_valid'))
        self.summary_writer = tf.summary.FileWriter(self.log_dir + 'C')
        self.valid_summary_writer = tf.summary.FileWriter(
            self.log_dir + 'C_valid')
예제 #11
0
    def __init__(self, config, critic_inference, graph):
        """ Build the generator graph.

        Inputs
        ------
        config :
            * batch_size : mini batch size (fixed first dim of the
              'latent_input' and 'team_a' placeholders)
            * log_dir : path prefix for summaries ('G' is appended)
            * learning_rate : adam's learning rate
            * seq_length : length of LSTM
            * latent_dims : dimensions of latent feature
            * num_layers, hidden_size, num_features : copied onto the instance
            * penalty_lambda : gradient penalty's weight, ref from paper of 'improved-wgan'
            * latent_penalty_lambda, n_resblock, if_feed_extra_info,
              residual_alpha, leaky_relu_alpha : copied onto the instance
        critic_inference :
            stored as ``self.critic``
        graph :
            tensorflow default graph; used for the global step and written
            to the summary writer
        """
        self.data_factory = DataFactory()
        # hyper-parameters
        for attr_name in ('batch_size', 'log_dir', 'learning_rate',
                          'seq_length', 'latent_dims', 'num_layers',
                          'hidden_size', 'num_features', 'penalty_lambda',
                          'latent_penalty_lambda', 'n_resblock',
                          'if_feed_extra_info', 'residual_alpha',
                          'leaky_relu_alpha'):
            setattr(self, attr_name, getattr(config, attr_name))
        self.if_feed_prev = True
        # steps
        self.__global_steps = tf.train.get_or_create_global_step(graph=graph)
        self.__steps = 0
        # IO
        self.critic = critic_inference
        self.__z = tf.placeholder(
            dtype=tf.float32,
            shape=[self.batch_size, self.latent_dims],
            name='latent_input')
        self.__cond = tf.placeholder(
            dtype=tf.float32,
            shape=[self.batch_size, None, 13],
            name='team_a')
        # adversarial learning : wgan
        self.__build_model()

        # summary
        self.__summary_op = tf.summary.merge(tf.get_collection('G'))
        self.__summary_histogram_op = tf.summary.merge(
            tf.get_collection('G_histogram'))
        self.summary_writer = tf.summary.FileWriter(
            self.log_dir + 'G', graph=graph)
예제 #12
0
def main(args):
    """ Evaluate a SpatioAttentiveGraph model on HICO-DET.

    Prints the mean AP over all classes, over rare classes (fewer than 10
    training annotations) and over non-rare classes, together with the
    checkpoint epoch and the evaluation's elapsed time.
    """
    torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = False

    # split classes by their number of training annotations
    anno_file = os.path.join(args.data_root, 'instances_train2015.json')
    num_anno = torch.tensor(HICODet(None, anno_file=anno_file).anno_interaction)
    rare = torch.nonzero(num_anno < 10).squeeze(1)
    non_rare = torch.nonzero(num_anno >= 10).squeeze(1)

    test_set = DataFactory(
        name='hicodet',
        partition=args.partition,
        data_root=args.data_root,
        detection_root=args.detection_dir,
        box_score_thresh_h=args.human_thresh,
        box_score_thresh_o=args.object_thresh)
    dataloader = DataLoader(
        dataset=test_set,
        collate_fn=custom_collate,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=True)

    net = SpatioAttentiveGraph(
        dataloader.dataset.dataset.object_to_verb,
        49,
        num_iterations=args.num_iter,
        max_human=args.max_human,
        max_object=args.max_object)

    epoch = 0
    if os.path.exists(args.model_path):
        print("Loading model from ", args.model_path)
        checkpoint = torch.load(args.model_path, map_location="cpu")
        net.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint["epoch"]
    elif len(args.model_path):
        print("\nWARNING: The given model path does not exist. "
              "Proceed to use a randomly initialised model.\n")

    net.cuda()
    timer = pocket.utils.HandyTimer(maxlen=1)
    with timer:
        test_ap = test(net, dataloader)

    report = ("Model at epoch: {} | time elapsed: {:.2f}s\n"
              "Full: {:.4f}, rare: {:.4f}, non-rare: {:.4f}")
    print(report.format(epoch, timer[0], test_ap.mean(),
                        test_ap[rare].mean(), test_ap[non_rare].mean()))
예제 #13
0
def main(args):
    """ Run SCG inference over a dataset partition and cache the results.

    Supported datasets are 'hicodet' and 'vcoco'; results are written under
    ``args.cache_dir`` by ``inference_hicodet`` / ``inference_vcoco``.

    Raises
    ------
    ValueError
        if ``args.dataset`` is neither 'hicodet' nor 'vcoco'.
    """
    # Fail fast: the dataset-specific settings below exist only for these
    # two datasets; any other name would otherwise crash later on an unbound
    # local (object_to_target).
    if args.dataset not in ('hicodet', 'vcoco'):
        raise ValueError(
            "Unsupported dataset {!r}; expected 'hicodet' or 'vcoco'".format(
                args.dataset))

    torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = False

    os.makedirs(args.cache_dir, exist_ok=True)

    dataloader = DataLoader(
        dataset=DataFactory(
            name=args.dataset, partition=args.partition,
            data_root=args.data_root,
            detection_root=args.detection_dir,
        ), collate_fn=custom_collate, batch_size=1,
        num_workers=args.num_workers, pin_memory=True
    )

    if args.dataset == 'hicodet':
        object_to_target = dataloader.dataset.dataset.object_to_verb
        human_idx = 49
        num_classes = 117
    else:  # 'vcoco'
        object_to_target = dataloader.dataset.dataset.object_to_action
        human_idx = 1
        num_classes = 24
    net = SCG(
        object_to_target, human_idx, num_classes=num_classes,
        num_iterations=args.num_iter,
        max_human=args.max_human, max_object=args.max_object,
        box_score_thresh=args.box_score_thresh
    )
    if os.path.exists(args.model_path):
        print("Loading model from ", args.model_path)
        checkpoint = torch.load(args.model_path, map_location="cpu")
        net.load_state_dict(checkpoint['model_state_dict'])
    elif len(args.model_path):
        print("\nWARNING: The given model path does not exist. "
              "Proceed to use a randomly initialised model.\n")

    net.cuda()

    if args.dataset == 'hicodet':
        # HICO-DET inference needs the COCO -> HICO object index mapping
        with open(os.path.join(args.data_root, 'coco80tohico80.json'), 'r') as f:
            coco2hico = json.load(f)
        inference_hicodet(net, dataloader, coco2hico, args.cache_dir)
    else:  # 'vcoco'
        inference_vcoco(net, dataloader, args.cache_dir)
예제 #14
0
def main(args):
    """ Visualise the model's prediction on one HICO-DET sample.

    Builds the dataset, loads ``args.model_path`` when it exists (otherwise
    keeps the randomly initialised weights), runs the network on sample
    ``args.index`` and draws the result with ``visualise_entire_image``.
    """
    dataset = DataFactory(
        name='hicodet',
        partition=args.partition,
        data_root=args.data_root,
        detection_root=args.detection_dir,
    )

    net = SCG(dataset.dataset.object_to_verb,
              49,
              num_iterations=args.num_iter,
              box_score_thresh=args.box_score_thresh)
    net.eval()

    if os.path.exists(args.model_path):
        print("\nLoading model from ", args.model_path)
        checkpoint = torch.load(args.model_path, map_location="cpu")
        net.load_state_dict(checkpoint['model_state_dict'])
    elif len(args.model_path):
        print("\nWARNING: The given model path does not exist. "
              "Proceed to use a randomly initialised model.\n")
    else:
        print("\nProceed with a randomly initialised model\n")

    image, detection, target = dataset[args.index]

    # Inference only: no autograd graph is needed for visualisation.
    # Inputs are wrapped in lists to form a batch of one.
    with torch.no_grad():
        output = net([image], [detection], [target])
    visualise_entire_image(dataset, output[0])
예제 #15
0
def main(_):
    """ Load the real/sequence data and their condition features from
    ``FLAGS.data_path``, then build and run the trainer. """
    with tf.get_default_graph().as_default():
        def load(file_name):
            # every input lives directly under FLAGS.data_path
            return np.load(os.path.join(FLAGS.data_path, file_name))

        real_data = load('50Real.npy')[:, :FLAGS.seq_length, :, :]
        seq_data = load('50Seq.npy')
        features_ = load('SeqCond.npy')
        real_feat = load('RealCond.npy')

        for label, array in (("Real Data: ", real_data),
                             ("Seq Data: ", seq_data),
                             ("Real Feat: ", real_feat),
                             ("Seq Feat: ", features_)):
            print(label, array.shape)

        data_factory = DataFactory(real_data=real_data,
                                   seq_data=seq_data,
                                   features_=features_,
                                   real_feat=real_feat)

        config = Training_config()
        config.show()
        Trainer(data_factory, config)()
def main(rank, args):
    """ Train SCG in one process of a distributed (NCCL) job.

    Inputs
    ------
    rank : int
        rank of this process in the process group
    args :
        parsed command-line options (dataset, partitions, data/detection
        roots, optimisation hyper-parameters, checkpoint path, ...)

    When ``args.checkpoint_path`` exists, the model, optimiser and
    learning-rate scheduler states and the epoch/iteration counters are
    restored from it.

    Raises
    ------
    ValueError
        if ``args.dataset`` is neither 'hicodet' nor 'vcoco'.
    KeyError
        if a trainable parameter belongs to neither the backbone nor the
        interaction head.
    """
    # Fail fast: dataset-specific settings below exist only for these two
    # datasets; any other name would crash later on an unbound local.
    if args.dataset not in ('hicodet', 'vcoco'):
        raise ValueError(
            "Unsupported dataset {!r}; expected 'hicodet' or 'vcoco'".format(
                args.dataset))

    dist.init_process_group(backend="nccl",
                            init_method="env://",
                            world_size=args.world_size,
                            rank=rank)

    trainset = DataFactory(name=args.dataset,
                           partition=args.partitions[0],
                           data_root=args.data_root,
                           detection_root=args.train_detection_dir,
                           flip=True)

    valset = DataFactory(name=args.dataset,
                         partition=args.partitions[1],
                         data_root=args.data_root,
                         detection_root=args.val_detection_dir)

    train_loader = DataLoader(dataset=trainset,
                              collate_fn=custom_collate,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              sampler=DistributedSampler(
                                  trainset,
                                  num_replicas=args.world_size,
                                  rank=rank))

    val_loader = DataLoader(dataset=valset,
                            collate_fn=custom_collate,
                            batch_size=args.batch_size,
                            num_workers=args.num_workers,
                            pin_memory=True,
                            sampler=DistributedSampler(
                                valset,
                                num_replicas=args.world_size,
                                rank=rank))

    # Fix random seed for model synchronisation
    torch.manual_seed(args.random_seed)

    if args.dataset == 'hicodet':
        object_to_target = train_loader.dataset.dataset.object_to_verb
        human_idx = 49
        num_classes = 117
    else:  # 'vcoco'
        object_to_target = train_loader.dataset.dataset.object_to_action
        human_idx = 1
        num_classes = 24
    net = SCG(object_to_target,
              human_idx,
              num_classes=num_classes,
              num_iterations=args.num_iter,
              postprocess=False,
              max_human=args.max_human,
              max_object=args.max_object,
              box_score_thresh=args.box_score_thresh,
              distributed=True)

    if os.path.exists(args.checkpoint_path):
        print("=> Rank {}: continue from saved checkpoint".format(rank),
              args.checkpoint_path)
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
        optim_state_dict = checkpoint['optim_state_dict']
        sched_state_dict = checkpoint['scheduler_state_dict']
        epoch = checkpoint['epoch']
        iteration = checkpoint['iteration']
    else:
        print(
            "=> Rank {}: start from a randomly initialised model".format(rank))
        optim_state_dict = None
        sched_state_dict = None
        epoch = 0
        iteration = 0

    engine = CustomisedDLE(net,
                           train_loader,
                           val_loader,
                           num_classes=num_classes,
                           print_interval=args.print_interval,
                           cache_dir=args.cache_dir)
    # Separate backbone parameters from the rest
    param_group_1 = []
    param_group_2 = []
    for k, v in engine.fetch_state_key('net').named_parameters():
        if v.requires_grad:
            if k.startswith('module.backbone'):
                param_group_1.append(v)
            elif k.startswith('module.interaction_head'):
                param_group_2.append(v)
            else:
                raise KeyError(f"Unknown parameter name {k}")
    # Fine-tune backbone with lower learning rate
    optim = torch.optim.AdamW([
        {'params': param_group_1, 'lr': args.learning_rate * args.lr_decay},
        {'params': param_group_2},
    ], lr=args.learning_rate, weight_decay=args.weight_decay)

    def lr_factor(e):
        # step decay: full LR before the first milestone, x lr_decay after
        return 1. if e < args.milestones[0] else args.lr_decay

    # the same schedule applies to both parameter groups
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optim, lr_lambda=[lr_factor, lr_factor])
    # When resuming, restore the optimiser/scheduler states read from the
    # checkpoint; otherwise training would restart with fresh Adam moments
    # and a reset schedule.
    if optim_state_dict is not None:
        optim.load_state_dict(optim_state_dict)
    if sched_state_dict is not None:
        lr_scheduler.load_state_dict(sched_state_dict)
    # Override optimiser and learning rate scheduler
    engine.update_state_key(optimizer=optim, lr_scheduler=lr_scheduler)
    engine.update_state_key(epoch=epoch, iteration=iteration)

    engine(args.num_epochs)
    def __init__(self, config, graph, if_training=True):
        """ Build the critic graph.

        Inputs
        ------
        config :
            * batch_size : mini batch size
            * log_dir : path prefix for summaries ('C', 'C_valid' or
              'Baseline_C' is appended)
            * learning_rate : adam's learning rate
            * seq_length : length of sequence during training
            * penalty_lambda : gradient penalty's weight, ref from paper 'improved-wgan'
            * n_resblock : number of resblock in network body
            * if_handcraft_features : if_handcraft_features
            * residual_alpha : residual block = F(x) * residual_alpha + x
            * leaky_relu_alpha : tf.maximum(x, leaky_relu_alpha * x)
            * openshot_penalty_lambda : Critic = Critic - openshot_penalty_lambda * open_shot_score
            * if_use_mismatched : if True, negative scores = mean of (fake_scores + mismatched_scores)
            * n_filters : number of filters in all ConV
        graph :
            tensorflow default graph; used for the global step
        if_training : bool
            True  -> build training and validation summary ops/writers
            False -> only create the 'Baseline_C' summary writer
        """
        self.data_factory = DataFactory()
        # hyper-parameters
        for attr_name in ('batch_size', 'log_dir', 'learning_rate',
                          'seq_length', 'penalty_lambda', 'n_resblock',
                          'if_handcraft_features', 'residual_alpha',
                          'leaky_relu_alpha', 'openshot_penalty_lambda',
                          'if_use_mismatched', 'n_filters'):
            setattr(self, attr_name, getattr(config, attr_name))
        self.if_training = if_training

        # steps
        self.__global_steps = tf.train.get_or_create_global_step(graph=graph)
        with tf.name_scope('Critic'):
            # critic-local step counter, excluded from training
            self.__steps = tf.get_variable('C_steps',
                                           shape=[],
                                           dtype=tf.int32,
                                           initializer=tf.zeros_initializer(
                                               dtype=tf.int32),
                                           trainable=False)
            # data
            self.__G_samples = tf.placeholder(
                dtype=tf.float32, shape=[None, None, 10], name='G_samples')
            self.__real_data = tf.placeholder(
                dtype=tf.float32, shape=[None, None, 10], name='real_data')
            self.__matched_cond = tf.placeholder(
                dtype=tf.float32, shape=[None, None, 13], name='matched_cond')
            # conditions shuffled along the first axis pair plays with
            # conditions from other plays
            self.__mismatched_cond = tf.random_shuffle(self.__matched_cond)
            # adversarial learning : wgan
            self.__build_model()

            # summary
            if not self.if_training:
                self.baseline_summary_writer = tf.summary.FileWriter(
                    self.log_dir + 'Baseline_C')
            else:
                self.__summary_op = tf.summary.merge(tf.get_collection('C'))
                self.__summary_histogram_op = tf.summary.merge(
                    tf.get_collection('C_histogram'))
                self.__summary_valid_op = tf.summary.merge(
                    tf.get_collection('C_valid'))
                self.summary_writer = tf.summary.FileWriter(
                    self.log_dir + 'C')
                self.valid_summary_writer = tf.summary.FileWriter(
                    self.log_dir + 'C_valid')
class C_MODEL(object):
    """ Model of Critic Network

    Builds a WGAN critic with a gradient penalty (ref 'improved-wgan') under
    the 'Critic' name scope. The critic scores team-B (defensive players)
    trajectories conditioned on the ball and team-A (offensive players)
    features.

    If ``if_training`` is True the constructor also builds the critic loss,
    an Adam train op and training/validation summary writers (used by
    ``step`` and ``log_valid_loss``). Otherwise it builds ``EM_dist`` /
    ``summary_em`` and a 'Baseline_C' summary writer (used by
    ``eval_EM_distance``).

    NOTE: graph ops are created in a fixed order and tensors may be looked up
    elsewhere by their auto-generated names (e.g. 'Critic/C_inference_1/...'),
    so reordering graph construction here can break those lookups.
    """
    def __init__(self, config, graph, if_training=True):
        """ Build up the graph
        Inputs
        ------
        config :
            * batch_size : mini batch size
            * log_dir : path to save training summary
            * learning_rate : adam's learning rate
            * seq_length : length of sequence during training
            * penalty_lambda : gradient penalty's weight, ref from paper 'improved-wgan'
            * n_resblock : number of resblock in network body
            * if_handcraft_features : if True, inputs go through DataFactory.extract_features
            * residual_alpha : residual block = F(x) * residual_alpha + x
            * leaky_relu_alpha : tf.maximum(x, leaky_relu_alpha * x)
            * openshot_penalty_lambda : weight of the open-shot penalty added to the generator loss in loss_for_G
            * if_use_mismatched : if True, negative scores = mean of (fake_scores + mismatched_scores)
            * n_filters : number of filters in all ConV
        graph :
            tensorflow default graph
        if_training : bool, optional, default True
            if True, build loss / optimizer / training summaries;
            otherwise build only the Earth Moving Distance evaluation ops
        """
        self.data_factory = DataFactory()
        # hyper-parameters
        self.batch_size = config.batch_size
        self.log_dir = config.log_dir
        self.learning_rate = config.learning_rate
        self.seq_length = config.seq_length
        self.penalty_lambda = config.penalty_lambda
        self.n_resblock = config.n_resblock
        self.if_handcraft_features = config.if_handcraft_features
        self.residual_alpha = config.residual_alpha
        self.leaky_relu_alpha = config.leaky_relu_alpha
        self.openshot_penalty_lambda = config.openshot_penalty_lambda
        self.if_use_mismatched = config.if_use_mismatched
        self.n_filters = config.n_filters
        self.if_training = if_training

        # steps
        self.__global_steps = tf.train.get_or_create_global_step(graph=graph)
        with tf.name_scope('Critic'):
            # critic-only iteration counter; the train op increments it
            self.__steps = tf.get_variable(
                'C_steps',
                shape=[],
                dtype=tf.int32,
                initializer=tf.zeros_initializer(dtype=tf.int32),
                trainable=False)
            # data : [batch, time, features]
            self.__G_samples = tf.placeholder(dtype=tf.float32,
                                              shape=[None, None, 10],
                                              name='G_samples')
            self.__real_data = tf.placeholder(dtype=tf.float32,
                                              shape=[None, None, 10],
                                              name='real_data')
            self.__matched_cond = tf.placeholder(dtype=tf.float32,
                                                 shape=[None, None, 13],
                                                 name='matched_cond')
            # shuffle the conditions along the batch axis so real data gets
            # paired with other samples' conditions ("mismatched" negatives)
            self.__mismatched_cond = tf.random_shuffle(self.__matched_cond)
            # adversarial learning : wgan
            self.__build_model()

            # summary
            if self.if_training:
                self.__summary_op = tf.summary.merge(tf.get_collection('C'))
                self.__summary_histogram_op = tf.summary.merge(
                    tf.get_collection('C_histogram'))
                self.__summary_valid_op = tf.summary.merge(
                    tf.get_collection('C_valid'))
                self.summary_writer = tf.summary.FileWriter(self.log_dir + 'C')
                self.valid_summary_writer = tf.summary.FileWriter(
                    self.log_dir + 'C_valid')
            else:
                self.baseline_summary_writer = tf.summary.FileWriter(
                    self.log_dir + 'Baseline_C')

    def __build_model(self):
        """ Score real / fake (/ mismatched) inputs and build the train or
        evaluation ops depending on ``self.if_training``.
        """
        # the first inference call creates the variables, later ones reuse them
        self.real_scores = self.inference(self.__real_data,
                                          self.__matched_cond)
        self.fake_scores = self.inference(self.__G_samples,
                                          self.__matched_cond,
                                          reuse=True)
        # neg_scores: the scores the critic loss pushes down
        if self.if_use_mismatched:
            mismatched_scores = self.inference(self.__real_data,
                                               self.__mismatched_cond,
                                               reuse=True)
            neg_scores = (self.fake_scores + mismatched_scores) / 2.0
        else:
            neg_scores = self.fake_scores

        if self.if_training:
            # loss function
            self.__loss = self.__loss_fn(self.__real_data, self.__G_samples,
                                         neg_scores, self.real_scores,
                                         self.penalty_lambda)
            theta = libs.get_var_list('C')
            with tf.name_scope('optimizer') as scope:
                # Critic train one iteration, step++
                # (the increment is a control dependency of the gradient and
                # apply ops, so it runs whenever the train op runs)
                assign_add_ = tf.assign_add(self.__steps, 1)
                with tf.control_dependencies([assign_add_]):
                    optimizer = tf.train.AdamOptimizer(
                        learning_rate=self.learning_rate, beta1=0.5, beta2=0.9)
                    grads = tf.gradients(self.__loss, theta)
                    grads = list(zip(grads, theta))
                    self.__train_op = optimizer.apply_gradients(
                        grads_and_vars=grads, global_step=self.__global_steps)
            # histogram logging
            for grad, var in grads:
                tf.summary.histogram(var.name + '_gradient',
                                     grad,
                                     collections=['C_histogram'])
        else:
            # evaluation only: Earth Moving Distance = mean(real) - mean(fake)
            f_fake = tf.reduce_mean(self.fake_scores)
            f_real = tf.reduce_mean(self.real_scores)
            with tf.name_scope('C_loss') as scope:
                self.EM_dist = f_real - f_fake
                self.summary_em = tf.summary.scalar('Earth Moving Distance',
                                                    self.EM_dist)

    def inference(self, inputs, conds, reuse=False):
        """ Critic network: one score per sample.

        Inputs
        ------
        inputs : tensor, float, shape=[batch_size, seq_length, features=10]
            real(from data) or fake(from G)
        conds : tensor, float, shape=[batch_size, seq_length, features=13]
            conditions, ball and team A
        reuse : bool, optional, default value is False
            if True, share the variables created by an earlier call

        Return
        ------
        score : tensor, float, shape=[batch_size]
            real(from data) or fake(from G); the per-frame outputs of the
            last conv are averaged over the time axis
        """
        with tf.variable_scope('C_inference', reuse=reuse):
            concat_ = tf.concat([conds, inputs], axis=-1)
            if self.if_handcraft_features:
                concat_ = self.data_factory.extract_features(concat_)
            with tf.variable_scope('conv_input') as scope:
                conv_input = tf.layers.conv1d(
                    inputs=concat_,
                    filters=self.n_filters,
                    kernel_size=5,
                    strides=1,
                    padding='same',
                    activation=libs.leaky_relu,
                    kernel_initializer=layers.xavier_initializer(),
                    bias_initializer=tf.zeros_initializer())
            # residual block
            next_input = conv_input
            for i in range(self.n_resblock):
                res_block = libs.residual_block(
                    'Res' + str(i),
                    next_input,
                    n_filters=self.n_filters,
                    n_layers=2,
                    residual_alpha=self.residual_alpha,
                    leaky_relu_alpha=self.leaky_relu_alpha)
                next_input = res_block
            with tf.variable_scope('conv_output') as scope:
                normed = layers.layer_norm(next_input)
                nonlinear = libs.leaky_relu(normed)
                # single filter -> [batch, time, 1]
                conv_output = tf.layers.conv1d(
                    inputs=nonlinear,
                    filters=1,
                    kernel_size=5,
                    strides=1,
                    padding='same',
                    activation=libs.leaky_relu,
                    kernel_initializer=layers.xavier_initializer(),
                    bias_initializer=tf.zeros_initializer())
                # average over time -> [batch, 1], then flatten -> [batch]
                conv_output = tf.reduce_mean(conv_output, axis=1)
                final_ = tf.reshape(conv_output, shape=[-1])
            return final_

    def loss_for_G(self, reals, fakes, conds, latent_weight_penalty):
        """ Generator loss built on top of this critic.

        loss = - mean(critic(fakes, conds))
               + |mean(critic(fakes, conds))|
                 * (openshot_penalty_lambda * openshot_penalty
                    + latent_weight_penalty)

        Param
        -----
        reals : tensor, float
            real team B data; reshaped to [batch_size, seq_length, 5, 2] in
            the open-shot score, so it must hold exactly that many elements
        fakes : tensor, float
            generated team B data, same layout as ``reals``
        conds : tensor, float
            conditions (ball and team A); features [:2] are used as the ball
            position by the open-shot score
        latent_weight_penalty : float or scalar tensor
            extra penalty term, scaled like the open-shot penalty

        Return
        ------
        loss : scalar tensor
        """
        openshot_penalty_lambda = tf.constant(self.openshot_penalty_lambda)
        openshot_penalty = self.__open_shot_penalty(reals,
                                                    conds,
                                                    fakes,
                                                    if_log=True)
        fake_scores = self.inference(fakes, conds, reuse=True)
        # penalties are scaled by the magnitude of the critic's mean output
        # (presumably to keep them commensurate with the adversarial term)
        scale_ = tf.abs(tf.reduce_mean(fake_scores))
        loss = - tf.reduce_mean(fake_scores) + scale_ * \
            openshot_penalty_lambda * openshot_penalty + scale_ * latent_weight_penalty
        return loss

    def __open_shot_penalty(self, reals, conds, fakes, if_log):
        """ |open_shot_score(reals) - open_shot_score(fakes)|.

        When ``if_log`` is True both scores are logged (collection 'G') under
        the 'real' and 'fake' name scopes.
        """
        real_os_penalty = self.__open_shot_score(reals,
                                                 conds,
                                                 if_log=if_log,
                                                 log_scope_name='real')
        fake_os_penalty = self.__open_shot_score(fakes,
                                                 conds,
                                                 if_log=if_log,
                                                 log_scope_name='fake')
        return tf.abs(real_os_penalty - fake_os_penalty)

    def __open_shot_score(self, inputs, conds, if_log, log_scope_name=''):
        """ Heuristic open-shot score plus a defender-spread term.

        Per frame and defender: (theta + 1.0) * (dist_ball_2_defender + 1.0),
        where theta is the angle at the ball between the defender and the
        right basket. The minimum over the 5 defenders is averaged over batch
        and time. ``too_close_penalty`` is minus the sum over defenders of the
        mean distance to every defender. Shapes are fixed by reshapes to
        ``self.batch_size`` / ``self.seq_length``.

        log_scope_name : string
            scope name for open_shot_score

        Return
        ------
        scalar tensor : open_shot_score + too_close_penalty
        """
        with tf.name_scope('open_shot_score') as scope:
            # calculate the open shot penalty on each frames
            # conds[:, :, :2] is used as the ball (x, y)
            ball_pos = tf.reshape(
                conds[:, :, :2],
                shape=[self.batch_size, self.seq_length, 1, 2])
            teamB_pos = tf.reshape(
                inputs, shape=[self.batch_size, self.seq_length, 5, 2])
            basket_right_x = tf.constant(
                self.data_factory.BASKET_RIGHT[0],
                dtype=tf.float32,
                shape=[self.batch_size, self.seq_length, 1, 1])
            basket_right_y = tf.constant(
                self.data_factory.BASKET_RIGHT[1],
                dtype=tf.float32,
                shape=[self.batch_size, self.seq_length, 1, 1])
            basket_pos = tf.concat([basket_right_x, basket_right_y], axis=-1)
            # open shot penalty = amin((theta + 1.0) * (dist_ball_2_teamB + 1.0))
            vec_ball_2_teamB = ball_pos - teamB_pos
            vec_ball_2_basket = ball_pos - basket_pos
            # [B, T, 5, 2] x [B, T, 2, 1] -> [B, T, 5, 1] -> [B, T, 5]
            b2teamB_dot_b2basket = tf.matmul(vec_ball_2_teamB,
                                             vec_ball_2_basket,
                                             transpose_b=True)
            b2teamB_dot_b2basket = tf.reshape(
                b2teamB_dot_b2basket,
                shape=[self.batch_size, self.seq_length, 5])
            dist_ball_2_teamB = tf.norm(vec_ball_2_teamB,
                                        ord='euclidean',
                                        axis=-1)
            dist_ball_2_basket = tf.norm(vec_ball_2_basket,
                                         ord='euclidean',
                                         axis=-1)

            # 1e-3 keeps the denominator away from zero
            theta = tf.acos(b2teamB_dot_b2basket /
                            (dist_ball_2_teamB * dist_ball_2_basket + 1e-3))
            open_shot_score_all = (theta + 1.0) * (dist_ball_2_teamB + 1.0)

            # add
            # one_sub_cosine = 1 - b2teamB_dot_b2basket / \
            #     (dist_ball_2_teamB * dist_ball_2_basket)
            # open_shot_score_all = one_sub_cosine + dist_ball_2_teamB

            open_shot_score_min = tf.reduce_min(open_shot_score_all, axis=-1)
            open_shot_score = tf.reduce_mean(open_shot_score_min)

            # too close penalty
            too_close_penalty = 0.0
            for i in range(5):
                vec = tf.subtract(teamB_pos[:, :, i:i + 1], teamB_pos)
                # vec includes the zero self-difference; the 1e-8 offsets keep
                # sqrt away from 0 so its gradient stays finite
                dist = tf.sqrt((vec[:, :, :, 0] + 1e-8)**2 +
                               (vec[:, :, :, 1] + 1e-8)**2)
                too_close_penalty -= tf.reduce_mean(dist)

        if if_log:
            with tf.name_scope(log_scope_name):
                tf.summary.scalar('open_shot_score',
                                  open_shot_score,
                                  collections=['G'])
                tf.summary.scalar('too_close_penalty',
                                  too_close_penalty,
                                  collections=['G'])
        return open_shot_score + too_close_penalty

    def __loss_fn(self, real_data, G_sample, fake_scores, real_scores,
                  penalty_lambda):
        """ Critic loss

        Params
        ------
        real_data : tensor, float, shape=[batch_size, seq_length, features=10]
            real data, team B, defensive players
        G_sample : tensor, float, shape=[batch_size, seq_length, features=10]
            fake data, team B, defensive players
        fake_scores : tensor, float, shape=[batch_size]
            result from inference given fake data
        real_scores : tensor, float, shape=[batch_size]
            result from inference given real data
        penalty_lambda : float
            gradient penalty's weight, ref from paper 'improved-wgan'

        Return
        ------
        loss : float, shape=[]
            the mean loss of one batch
        """
        with tf.name_scope('C_loss') as scope:
            # grad_pen, base on paper (Improved-WGAN)
            # one interpolation factor per sample, broadcast over time/features
            epsilon = tf.random_uniform([self.batch_size, 1, 1],
                                        minval=0.0,
                                        maxval=1.0)
            X_inter = epsilon * real_data + (1.0 - epsilon) * G_sample
            if self.if_use_mismatched:
                cond_inter = epsilon * self.__matched_cond + \
                    (1.0 - epsilon) * self.__mismatched_cond
            else:
                cond_inter = self.__matched_cond

            # gradient norm is taken w.r.t. the interpolated data only
            grad = tf.gradients(
                self.inference(X_inter, cond_inter, reuse=True), [X_inter])[0]
            sum_ = tf.reduce_sum(tf.square(grad), axis=[1, 2])
            grad_norm = tf.sqrt(sum_)
            grad_pen = penalty_lambda * tf.reduce_mean(
                tf.square(grad_norm - 1.0))
            # named so it can be fetched by name ('.../EM_dist:0')
            EM_dist = tf.identity(real_scores - fake_scores, name="EM_dist")
            f_fake = tf.reduce_mean(fake_scores)
            f_real = tf.reduce_mean(real_scores)
            # Earth Moving Distance
            loss = f_fake - f_real + grad_pen

            # logging
            tf.summary.scalar('C_loss', loss, collections=['C', 'C_valid'])
            tf.summary.scalar('F_real', f_real, collections=['C'])
            tf.summary.scalar('F_fake', f_fake, collections=['C'])
            tf.summary.scalar('Earth Moving Distance',
                              f_real - f_fake,
                              collections=['C', 'C_valid'])
            tf.summary.scalar('grad_pen', grad_pen, collections=['C'])

        return loss

    def step(self, sess, G_samples, real_data, conditions):
        """ train one batch on C

        Only valid when built with if_training=True.

        Params
        ------
        sess : tensorflow Session
        G_samples : float, shape=[batch_size, seq_length, features=10]
            fake data, team B, defensive players
        real_data : float, shape=[batch_size, seq_length, features=10]
            real data, team B, defensive players
        conditions : float, shape=[batch_size, seq_length, features=13]
            real data, team A, offensive players

        Returns
        -------
        loss : float
            batch mean loss
        global_steps : int
            global steps
        """
        feed_dict = {
            self.__G_samples: G_samples,
            self.__matched_cond: conditions,
            self.__real_data: real_data
        }
        steps, summary, loss, global_steps, _ = sess.run([
            self.__steps, self.__summary_op, self.__loss, self.__global_steps,
            self.__train_op
        ],
                                                         feed_dict=feed_dict)
        # log
        self.summary_writer.add_summary(summary, global_step=global_steps)
        # gradient histograms only every 1000 critic steps (extra sess.run)
        if (steps - 1) % 1000 == 0:
            summary_histogram = sess.run(self.__summary_histogram_op,
                                         feed_dict=feed_dict)
            self.summary_writer.add_summary(summary_histogram,
                                            global_step=global_steps)

        return loss, global_steps

    def log_valid_loss(self, sess, G_samples, real_data, conditions):
        """ get one batch validation loss

        Only valid when built with if_training=True; does not train.

        Params
        ------
        sess : tensorflow Session
        G_samples : float, shape=[batch_size, seq_length, features=10]
            fake data, team B, defensive players
        real_data : float, shape=[batch_size, seq_length, features=10]
            real data, team B, defensive players
        conditions : float, shape=[batch_size, seq_length, features=13]
            real data, team A, offensive players

        Returns
        -------
        loss : float
            validation batch mean loss
        """
        feed_dict = {
            self.__G_samples: G_samples,
            self.__matched_cond: conditions,
            self.__real_data: real_data
        }
        summary, loss, global_steps = sess.run(
            [self.__summary_valid_op, self.__loss, self.__global_steps],
            feed_dict=feed_dict)
        # log
        self.valid_summary_writer.add_summary(summary,
                                              global_step=global_steps)
        return loss

    def eval_EM_distance(self, sess, G_samples, real_data, conditions,
                         global_steps):
        """ Evaluate the Earth Moving Distance on one batch and log it.

        Only valid when built with if_training=False (``EM_dist``,
        ``summary_em`` and ``baseline_summary_writer`` exist only then).

        Params
        ------
        sess : tensorflow Session
        G_samples : float, shape=[batch_size, seq_length, features=10]
            fake data, team B, defensive players
        real_data : float, shape=[batch_size, seq_length, features=10]
            real data, team B, defensive players
        conditions : float, shape=[batch_size, seq_length, features=13]
            real data, team A, offensive players
        global_steps : int
            step used for the summary

        Returns
        -------
        EM_dist : float
            mean(real scores) - mean(fake scores) on this batch
        """
        feed_dict = {
            self.__G_samples: G_samples,
            self.__matched_cond: conditions,
            self.__real_data: real_data
        }
        EM_dist, summary = sess.run([self.EM_dist, self.summary_em],
                                    feed_dict=feed_dict)
        self.baseline_summary_writer.add_summary(summary,
                                                 global_step=global_steps)
        return EM_dist
예제 #19
0
    def __init__(self, config, graph, if_training=True):
        """ Build the critic graph with fixed-size placeholders.

        Inputs
        ------
        config :
            hyper-parameter container. Read here: batch_size, log_dir,
            learning_rate, seq_length (length of LSTM), latent_dims, the RNN
            parameters num_features / num_layers / hidden_size, then
            penalty_lambda (gradient penalty weight, ref 'improved-wgan'),
            if_handcraft_features, leaky_relu_alpha, heuristic_penalty_lambda,
            if_use_mismatched and if_trainable_lambda
        graph :
            tensorflow default graph
        if_training : bool, optional
            True -> create training / validation summary ops and writers;
            False -> create only the 'Baseline_C' summary writer
        """
        self.data_factory = DataFactory()
        # copy hyper-parameters, in the same order they were always read
        for attr_name in ('batch_size', 'log_dir', 'learning_rate',
                          'seq_length', 'latent_dims',
                          # RNN parameters
                          'num_features', 'num_layers', 'hidden_size'):
            setattr(self, attr_name, getattr(config, attr_name))
        self.if_training = if_training
        for attr_name in ('penalty_lambda', 'if_handcraft_features',
                          'leaky_relu_alpha', 'heuristic_penalty_lambda',
                          'if_use_mismatched', 'if_trainable_lambda'):
            setattr(self, attr_name, getattr(config, attr_name))

        # steps: global TF step + a plain Python counter for this critic
        self.__global_steps = tf.train.get_or_create_global_step(graph=graph)
        self.__steps = 0

        # placeholders, all [batch_size, seq_length, features]
        frame_dims = [self.batch_size, self.seq_length]
        self.__G_samples = tf.placeholder(dtype=tf.float32,
                                          shape=frame_dims + [10],
                                          name='G_samples')
        self.__real_data = tf.placeholder(dtype=tf.float32,
                                          shape=frame_dims + [10],
                                          name='real_data')
        self.__matched_cond = tf.placeholder(dtype=tf.float32,
                                             shape=frame_dims + [13],
                                             name='matched_cond')
        # conditions shuffled along the batch axis -> mismatched pairs
        self.__mismatched_cond = tf.random_shuffle(self.__matched_cond)
        # adversarial learning : wgan
        self.__build_model()

        # summaries
        if not self.if_training:
            self.baseline_summary_writer = tf.summary.FileWriter(
                self.log_dir + 'Baseline_C')
            return
        self.__summary_op = tf.summary.merge(tf.get_collection('C'))
        # NOTE: gradient-histogram summaries ('C_histogram') are not merged here
        self.__summary_valid_op = tf.summary.merge(
            tf.get_collection('C_valid'))
        self.summary_writer = tf.summary.FileWriter(self.log_dir + 'C')
        self.valid_summary_writer = tf.summary.FileWriter(
            self.log_dir + 'C_valid')
def rnn():
    """ to collect results vary in length

    Restores the model from FLAGS.restore_path and, for each of the last 100
    plays in '../../data/FixedFPS5.npy', generates FLAGS.n_latents team-B
    results conditioned on that play's ball + team A. Each generated result
    keeps the play's true length (from '../../data/FixedFPS5Length.npy') and
    is zero-padded after it up to the padded length of the target data.
    Results are written to os.path.join(COLLECT_PATH, 'rnn').

    Saved Result
    ------------
    results_A_fake_B.npy : float32, shape=[n_latents, n_conditions=100, length, features=23]
        Real A + Fake B, passed through data_factory.recover_data
    results_A_real_B.npy : float32, shape=[n_conditions=100, length, features=23]
        Real A + Real B, passed through data_factory.recover_data
    """
    save_path = os.path.join(COLLECT_PATH, 'rnn')
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    # DataFactory built from FLAGS.data_path; its normalize / recover_data
    # are applied to the target plays below
    data_factory = DataFactory(real_data)
    # target data: the last 100 plays and their true (unpadded) lengths
    target_data = np.load('../../data/FixedFPS5.npy')[-100:]
    target_length = np.load('../../data/FixedFPS5Length.npy')[-100:]
    print('target_data.shape', target_data.shape)
    n_plays, n_frames = target_data.shape[0], target_data.shape[1]
    team_AB = np.concatenate(
        [
            # ball
            target_data[:, :, 0, :3].reshape([n_plays, n_frames, 1 * 3]),
            # team A players
            target_data[:, :, 1:6, :2].reshape([n_plays, n_frames, 5 * 2]),
            # team B players
            target_data[:, :, 6:11, :2].reshape([n_plays, n_frames, 5 * 2])
        ],
        axis=-1)
    team_AB = data_factory.normalize(team_AB)
    # features [:13] = ball (3) + team A (10); [13:] = team B (10)
    team_A = team_AB[:, :, :13]
    # result collector
    results_A_fake_B = []
    config = TrainingConfig(235)
    with tf.get_default_graph().as_default() as graph:
        # model
        C = C_MODEL(config, graph)
        G = G_MODEL(config, C.inference, graph)
        tfconfig = tf.ConfigProto()
        tfconfig.gpu_options.allow_growth = True
        default_sess = tf.Session(config=tfconfig, graph=graph)
        # always release the session, even if restore/generation fails
        try:
            # saver for later restore
            saver = tf.train.Saver(max_to_keep=0)  # 0 -> keep them all
            # restore model if exist
            saver.restore(default_sess, FLAGS.restore_path)
            print('successfully restore model from checkpoint: %s' %
                  (FLAGS.restore_path))
            for idx in range(team_AB.shape[0]):
                # given FLAGS.n_latents latents, generate that many results
                # on the same condition at once
                real_conds = np.repeat(team_A[idx:idx + 1],
                                       FLAGS.n_latents,
                                       axis=0)
                latents = z_samples(FLAGS.n_latents)
                result = G.generate(default_sess, latents, real_conds)
                recoverd_A_fake_B = data_factory.recover_data(
                    np.concatenate([real_conds, result], axis=-1))
                # keep the true length, zero-pad up to team_AB.shape[1]
                dummy = np.zeros(shape=[
                    FLAGS.n_latents, team_AB.shape[1] - target_length[idx],
                    team_AB.shape[2]
                ])
                results_A_fake_B.append(
                    np.concatenate(
                        [recoverd_A_fake_B[:, :target_length[idx]], dummy],
                        axis=1))
        finally:
            default_sess.close()
    print(np.array(results_A_fake_B).shape)
    # concat along with conditions dimension (axis=1)
    results_A_fake_B = np.stack(results_A_fake_B, axis=1)
    # real data
    results_A_real_B = data_factory.recover_data(team_AB)
    # saved as numpy
    print(np.array(results_A_fake_B).shape)
    print(np.array(results_A_real_B).shape)
    np.save(os.path.join(save_path, 'results_A_fake_B.npy'),
            np.array(results_A_fake_B).astype(np.float32).reshape(
                [FLAGS.n_latents, team_AB.shape[0], team_AB.shape[1], 23]))
    np.save(os.path.join(save_path, 'results_A_real_B.npy'),
            np.array(results_A_real_B).astype(np.float32).reshape(
                [team_AB.shape[0], team_AB.shape[1], 23]))
    print('!!Completely Saved!!')
def mode_5(sess, graph, save_path):
    """ Compare the heuristic penalty of selected real and generated plays.

    Hand-picked (condition id, paired id) entries, grouped as normal / good /
    best, index the arrays saved under 'v3/2/collect/mode_1/'; the selection
    is padded with 28 dummy entries (id 0). Real and fake plays are
    normalized with a DataFactory built from FLAGS.data_path, fed through the
    tensor 'Critic/C_inference/heuristic_penalty/Min:0', and the mean
    per-frame penalty over the selected (non-dummy) entries is printed for
    real, then for fake.

    NOTE(review): the 28 dummies presumably fill a fixed batch size
    (57 + 21 + 22 + 28 = 128) — confirm against the placeholder shape.

    Params
    ------
    sess : tensorflow Session with the model restored
    graph : tensorflow Graph holding the tensors fetched below
    save_path : string
        directory created if missing; nothing is written into it here
    """
    # condition ids, and at the same positions the ids used as the first
    # index into results_A_fake_B (presumably the latent id)
    normal_c = [154, 108, 32, 498, 2, 513, 263, 29, 439, 249, 504, 529, 24,
                964, 641, 739, 214, 139, 819, 1078, 772, 349, 676, 1016, 582,
                678, 39, 279, 918, 477, 809, 505, 896, 600, 564, 50, 810,
                1132, 683, 578, 1131, 887, 621, 1097, 665, 528, 310, 631,
                1102, 6, 945, 1020, 853, 490, 64, 1002, 656]
    normal_n = [58, 5, 47, 66, 79, 21, 70, 54, 3, 59, 67, 59, 84, 38, 71, 62,
                55, 86, 14, 83, 94, 97, 83, 27, 38, 68, 95, 26, 60, 2, 54, 46,
                34, 75, 38, 4, 59, 87, 52, 44, 92, 28, 86, 71, 24, 28, 13, 70,
                87, 44, 52, 25, 59, 61, 86, 16, 98]
    good_c = [976, 879, 293, 750, 908, 878, 831, 1038, 486, 268, 265, 252,
              1143, 383, 956, 974, 199, 777, 585, 34, 932]
    good_n = [52, 16, 87, 43, 45, 66, 22, 77, 36, 50, 47, 9, 34, 9, 82, 42,
              65, 43, 7, 29, 62]
    best_c = [570, 517, 962, 1088, 35, 623, 1081, 33, 255, 571, 333, 990,
              632, 431, 453, 196, 991, 267, 591, 902, 597, 646]
    best_n = [22, 42, 76, 92, 12, 74, 92, 58, 69, 69, 23, 63, 89, 7, 74, 27,
              12, 20, 35, 77, 62, 63]
    # number of real (non-dummy) entries at the front of the batch
    n_selected = len(normal_c) + len(good_c) + len(best_c)

    padding = np.zeros(shape=[28])
    c_ids = np.concatenate([normal_c, good_c, best_c, padding]).astype(np.int32)
    n_ids = np.concatenate([normal_n, good_n, best_n, padding]).astype(np.int32)
    print(c_ids.shape)
    print(n_ids.shape)
    fake_result_AB = np.load(
        'v3/2/collect/mode_1/results_A_fake_B.npy')[n_ids, c_ids]
    real_result_AB = np.load(
        'v3/2/collect/mode_1/results_A_real_B.npy')[c_ids]
    print(fake_result_AB.shape)
    print(real_result_AB.shape)

    # normalize with a DataFactory built from FLAGS.data_path
    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    data_factory = DataFactory(real_data)
    fake_result_AB = data_factory.normalize(fake_result_AB)
    real_result_AB = data_factory.normalize(real_result_AB)

    real_data_t = graph.get_tensor_by_name('real_data:0')
    matched_cond_t = graph.get_tensor_by_name('matched_cond:0')
    heuristic_penalty_pframe = graph.get_tensor_by_name(
        'Critic/C_inference/heuristic_penalty/Min:0')

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    def penalty_per_frame(plays):
        # features 13:23 go in as the "real data", 0:13 as the condition
        return sess.run(heuristic_penalty_pframe,
                        feed_dict={
                            real_data_t: plays[:, :, 13:23],
                            matched_cond_t: plays[:, :, :13]
                        })

    real_hp_pframe = penalty_per_frame(real_result_AB)
    fake_hp_pframe = penalty_per_frame(fake_result_AB)

    for hp_pframe in (real_hp_pframe, fake_hp_pframe):
        print(np.mean(hp_pframe[:n_selected]))
    print('!!Completely Saved!!')
예제 #22
0
def main(rank, args):
    """ Per-process entry point for distributed training.

    Joins an NCCL process group (env:// init), builds train / val
    DataFactory datasets with DistributedSampler-backed loaders, constructs a
    SpatioAttentiveGraph with a frozen backbone, optionally resumes from
    ``args.checkpoint_path``, and runs CustomisedDLE for
    ``args.num_epochs`` epochs.

    Params
    ------
    rank : int
        rank of this process among ``args.world_size`` processes
    args : namespace
        parsed command-line options (dataset, data paths, thresholds,
        optimiser / scheduler settings, checkpoint and cache paths, ...)

    Raises
    ------
    ValueError
        if ``args.dataset`` is neither 'hicodet' nor 'vcoco'
    """
    dist.init_process_group(backend="nccl",
                            init_method="env://",
                            world_size=args.world_size,
                            rank=rank)

    trainset = DataFactory(name=args.dataset,
                           partition=args.partitions[0],
                           data_root=args.data_root,
                           detection_root=args.train_detection_dir,
                           box_score_thresh_h=args.human_thresh,
                           box_score_thresh_o=args.object_thresh,
                           flip=True)

    valset = DataFactory(name=args.dataset,
                         partition=args.partitions[1],
                         data_root=args.data_root,
                         detection_root=args.val_detection_dir,
                         box_score_thresh_h=args.human_thresh,
                         box_score_thresh_o=args.object_thresh)

    train_loader = DataLoader(dataset=trainset,
                              collate_fn=custom_collate,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              sampler=DistributedSampler(
                                  trainset,
                                  num_replicas=args.world_size,
                                  rank=rank))

    val_loader = DataLoader(dataset=valset,
                            collate_fn=custom_collate,
                            batch_size=args.batch_size,
                            num_workers=args.num_workers,
                            pin_memory=True,
                            sampler=DistributedSampler(
                                valset,
                                num_replicas=args.world_size,
                                rank=rank))

    # Fix random seed for model synchronisation
    torch.manual_seed(args.random_seed)

    # dataset-specific metadata: object -> target mapping, human class index,
    # number of interaction classes
    if args.dataset == 'hicodet':
        object_to_target = train_loader.dataset.dataset.object_to_verb
        human_idx = 49
        num_classes = 117
    elif args.dataset == 'vcoco':
        object_to_target = train_loader.dataset.dataset.object_to_action
        human_idx = 1
        num_classes = 24
    else:
        # without this the names above would be unbound further down
        raise ValueError(
            "Unsupported dataset {!r}; expected 'hicodet' or 'vcoco'".format(
                args.dataset))
    net = SpatioAttentiveGraph(object_to_target,
                               human_idx,
                               num_iterations=args.num_iter,
                               postprocess=False,
                               max_human=args.max_human,
                               max_object=args.max_object)
    # Fix backbone parameters
    for p in net.backbone.parameters():
        p.requires_grad = False

    if os.path.exists(args.checkpoint_path):
        print("=> Rank {}: continue from saved checkpoint".format(rank),
              args.checkpoint_path)
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
        optim_state_dict = checkpoint['optim_state_dict']
        sched_state_dict = checkpoint['scheduler_state_dict']
        epoch = checkpoint['epoch']
        iteration = checkpoint['iteration']
    else:
        print(
            "=> Rank {}: start from a randomly initialised model".format(rank))
        optim_state_dict = None
        sched_state_dict = None
        epoch = 0
        iteration = 0

    engine = CustomisedDLE(net,
                           train_loader,
                           val_loader,
                           num_classes=num_classes,
                           optim_params={
                               'lr': args.learning_rate,
                               'momentum': args.momentum,
                               'weight_decay': args.weight_decay
                           },
                           optim_state_dict=optim_state_dict,
                           lr_scheduler=True,
                           lr_sched_params={
                               'milestones': args.milestones,
                               'gamma': args.lr_decay
                           },
                           print_interval=args.print_interval,
                           cache_dir=args.cache_dir)
    engine.update_state_key(epoch=epoch, iteration=iteration)
    # restore the scheduler only after the engine has created it
    if sched_state_dict is not None:
        engine.fetch_state_key('lr_scheduler').load_state_dict(
            sched_state_dict)

    engine(args.num_epochs)
예제 #23
0
def mode_9(sess, graph, save_path, is_valid=FLAGS.is_valid):
    """ to collect generated results for conditions that vary in length

    Conditions are the last 100 sequences of '../../data/FixedFPS5.npy'.
    Each one is truncated to its own length (read from
    '../../data/FixedFPS5Length.npy'), FLAGS.n_latents samples are generated
    for it in a single run, and the recovered result is zero-padded back to
    the stored sequence length so that all conditions can be stacked.

    Parameters
    ----------
    sess : session used to evaluate the tensors (``sess.run``)
    graph : graph the named tensors are looked up in
    save_path : str, prefix the output file names are appended to
        (``save_path + 'results_*.npy'``)
    is_valid : unused by this mode

    Saved Result
    ------------
    results_A_fake_B : float32 ndarray, shape=[n_latents=FLAGS.n_latents, n_conditions, length, features=23]
        Real A + Fake B; frames past each condition's own length are zeros
    results_A_real_B : float32 ndarray, shape=[n_conditions, length, features=23]
        Real A + Real B
    results_critic_scores : float32 ndarray, shape=[n_latents=FLAGS.n_latents, n_conditions]
        critic scores for each generated sample

    ``n_conditions`` is at most 100 and ``length`` is the stored sequence
    length of '../../data/FixedFPS5.npy'.
    """
    # placeholder tensor
    latent_input_t = graph.get_tensor_by_name('Generator/latent_input:0')
    team_a_t = graph.get_tensor_by_name('Generator/team_a:0')
    G_samples_t = graph.get_tensor_by_name('Critic/G_samples:0')
    matched_cond_t = graph.get_tensor_by_name('Critic/matched_cond:0')
    # result tensor
    result_t = graph.get_tensor_by_name(
        'Generator/G_inference/conv_result/conv1d/Maximum:0')
    critic_scores_t = graph.get_tensor_by_name(
        'Critic/C_inference_1/conv_output/Reshape:0')

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    # DataFactory is used below only for normalize() / recover_data()
    data_factory = DataFactory(real_data)
    # target data: last 100 sequences and their true lengths
    target_data = np.load('../../data/FixedFPS5.npy')[-100:]
    target_length = np.load('../../data/FixedFPS5Length.npy')[-100:]
    print('target_data.shape', target_data.shape)
    n_conds, max_length = target_data.shape[0], target_data.shape[1]
    team_AB = np.concatenate(
        [
            # ball
            target_data[:, :, 0, :3].reshape([n_conds, max_length, 1 * 3]),
            # team A players
            target_data[:, :, 1:6, :2].reshape([n_conds, max_length, 5 * 2]),
            # team B players
            target_data[:, :, 6:11, :2].reshape([n_conds, max_length, 5 * 2]),
        ],
        axis=-1)
    team_AB = data_factory.normalize(team_AB)
    n_features = team_AB.shape[-1]
    # condition = ball + team A (the first 3 + 10 = 13 features)
    team_A = team_AB[:, :, :13]

    # result collector
    results_A_fake_B = []
    results_critic_scores = []
    for idx in range(n_conds):
        length = int(target_length[idx])
        # repeat the truncated condition once per latent so that all
        # FLAGS.n_latents samples for it are generated in a single run
        real_conds = np.repeat(team_A[idx:idx + 1, :length],
                               FLAGS.n_latents, axis=0)
        latents = z_samples(FLAGS.n_latents)
        result = sess.run(result_t,
                          feed_dict={latent_input_t: latents,
                                     team_a_t: real_conds})
        # critic score of every generated sample under its condition
        em_dist = sess.run(critic_scores_t,
                           feed_dict={G_samples_t: result,
                                      matched_cond_t: real_conds})
        recovered_A_fake_B = data_factory.recover_data(
            np.concatenate([real_conds, result], axis=-1))
        # zero-pad the time axis back to max_length so all conditions stack
        results_A_fake_B.append(
            np.pad(recovered_A_fake_B,
                   [(0, 0), (0, max_length - length), (0, 0)]))
        results_critic_scores.append(em_dist)
    # stack along the conditions dimension (axis=1)
    results_A_fake_B = np.stack(results_A_fake_B, axis=1)
    results_critic_scores = np.stack(results_critic_scores, axis=1)
    # real data
    results_A_real_B = data_factory.recover_data(team_AB)
    # saved as numpy
    print(results_A_fake_B.shape)
    print(np.shape(results_A_real_B))
    print(results_critic_scores.shape)
    np.save(
        save_path + 'results_A_fake_B.npy',
        results_A_fake_B.astype(np.float32).reshape(
            [FLAGS.n_latents, n_conds, max_length, n_features]))
    np.save(
        save_path + 'results_A_real_B.npy',
        np.array(results_A_real_B).astype(np.float32).reshape(
            [n_conds, max_length, n_features]))
    np.save(
        save_path + 'results_critic_scores.npy',
        results_critic_scores.astype(np.float32).reshape(
            [FLAGS.n_latents, n_conds]))
    print('!!Completely Saved!!')
def mode_1(sess, graph, save_path, is_valid=FLAGS.is_valid):
    """ to collect results for the conditions stored in 'ADD-100.npy'

    The sequences of 'ADD-100.npy' are normalized and then zero-padded (in
    normalized space) up to FLAGS.batch_size rows so that they fill one
    generator batch. FLAGS.n_latents batches of latents are run on them, and
    only the real (unpadded) rows are saved.

    NOTE: ``is_valid`` currently has no effect; the targets always come from
    'ADD-100.npy'.

    Saved Result
    ------------
    results_A_fake_B : float32 ndarray, shape=[n_latents=FLAGS.n_latents, n_conditions, length=FLAGS.seq_length, features=23]
        Real A + Fake B
    results_A_real_B : float32 ndarray, shape=[n_conditions, length=FLAGS.seq_length, features=23]
        Real A + Real B
    results_critic_scores : float32 ndarray, shape=[n_latents=FLAGS.n_latents, n_conditions]
        critic scores for each input data

    ``n_conditions`` is the number of sequences in 'ADD-100.npy'.

    Raises
    ------
    ValueError : if 'ADD-100.npy' holds more sequences than FLAGS.batch_size
    """
    # placeholder tensor
    latent_input_t = graph.get_tensor_by_name('latent_input:0')
    team_a_t = graph.get_tensor_by_name('team_a:0')
    G_samples_t = graph.get_tensor_by_name('G_samples:0')
    matched_cond_t = graph.get_tensor_by_name('matched_cond:0')
    # result tensor
    result_t = graph.get_tensor_by_name(
        'Generator/G_inference/conv_result/conv1d/Maximum:0')
    critic_scores_t = graph.get_tensor_by_name(
        'Critic/C_inference_1/linear_result/BiasAdd:0')

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    # normalize
    data_factory = DataFactory(real_data)
    # NOTE(review): the train/valid split returned here is not used any more
    # (the targets are read from 'ADD-100.npy' below). The call is kept only
    # in case fetch_data() updates DataFactory or RNG state; confirm, then drop.
    data_factory.fetch_data()

    target_data = np.load('ADD-100.npy')
    n_real, seq_len = target_data.shape[0], target_data.shape[1]
    team_AB = np.concatenate(
        [
            # ball
            target_data[:, :, 0, :3].reshape([n_real, seq_len, 1 * 3]),
            # team A players
            target_data[:, :, 1:6, :2].reshape([n_real, seq_len, 5 * 2]),
            # team B players
            target_data[:, :, 6:11, :2].reshape([n_real, seq_len, 5 * 2]),
        ],
        axis=-1)
    team_AB = data_factory.normalize(team_AB)
    print(team_AB.shape)
    # Pad with zero rows up to exactly one batch; previously hard-coded as
    # 98 rows, which only worked for a 30-sequence file and batch_size 128.
    n_pad = FLAGS.batch_size - n_real
    if n_pad < 0:
        raise ValueError(
            "'ADD-100.npy' holds {} sequences, more than batch_size={}".format(
                n_real, FLAGS.batch_size))
    dummy_AB = np.zeros(shape=[n_pad, seq_len, team_AB.shape[-1]])
    team_AB = np.concatenate([team_AB, dummy_AB], axis=0)
    # condition = ball + team A (first 13 features); the rest is team B
    real_conds = team_AB[:, :, :13]
    real_samples = team_AB[:, :, 13:]

    # generate one batch of results per latent draw
    fake_per_latent = []
    critic_per_latent = []
    for _ in range(FLAGS.n_latents):
        latents = z_samples(FLAGS.batch_size)
        result = sess.run(result_t,
                          feed_dict={latent_input_t: latents,
                                     team_a_t: real_conds})
        critic_scores = sess.run(critic_scores_t,
                                 feed_dict={G_samples_t: result,
                                            matched_cond_t: real_conds})
        fake_per_latent.append(data_factory.recover_data(
            np.concatenate([real_conds, result], axis=-1)))
        critic_per_latent.append(critic_scores)
    # leading axis = latent draw, second axis = condition
    results_A_fake_B = np.array(fake_per_latent)
    results_critic_scores = np.array(critic_per_latent)
    results_A = data_factory.recover_BALL_and_A(real_conds)
    results_real_B = data_factory.recover_B(real_samples)
    results_A_real_B = np.concatenate([results_A, results_real_B], axis=-1)
    # saved as numpy (padding rows dropped)
    print(results_A_fake_B.shape)
    print(results_A_real_B.shape)
    print(results_critic_scores.shape)
    np.save(save_path + 'results_A_fake_B.npy',
            results_A_fake_B[:, :n_real].astype(np.float32).reshape(
                [FLAGS.n_latents, n_real, FLAGS.seq_length, 23]))
    np.save(save_path + 'results_A_real_B.npy',
            results_A_real_B[:n_real].astype(np.float32).reshape(
                [n_real, FLAGS.seq_length, 23]))
    np.save(save_path + 'results_critic_scores.npy',
            results_critic_scores[:, :n_real].astype(np.float32).reshape(
                [FLAGS.n_latents, n_real]))
    print('!!Completely Saved!!')
예제 #25
0
def main(args):
    """Build the data loaders, model and training engine, then run it.

    The model's backbone parameters are frozen. When ``args.checkpoint_path``
    exists, model weights, optimiser state, LR-scheduler state and the
    epoch/iteration counters are restored from it; otherwise training starts
    from a freshly constructed model.

    Args:
        args: parsed command-line namespace. ``args.dataset`` must be
            'hicodet' or 'vcoco'.

    Raises:
        ValueError: if ``args.dataset`` is not 'hicodet' or 'vcoco'.
    """
    # Fail fast: the dataset-specific settings below only exist for these two
    # datasets. Without this check any other value fell through the if/elif
    # and crashed later on an unbound local.
    if args.dataset not in ('hicodet', 'vcoco'):
        raise ValueError(
            "Unsupported dataset {!r}; expected 'hicodet' or 'vcoco'".format(
                args.dataset))

    torch.cuda.set_device(0)
    torch.manual_seed(args.random_seed)
    torch.backends.cudnn.benchmark = False

    train_loader = DataLoader(
        dataset=DataFactory(
            name=args.dataset,
            partition=args.partitions[0],
            data_root=args.data_root,
            detection_root=args.train_detection_dir,
            box_score_thresh_h=args.human_thresh,
            box_score_thresh_o=args.object_thresh,
            flip=True),
        collate_fn=custom_collate,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        shuffle=True)

    val_loader = DataLoader(
        dataset=DataFactory(
            name=args.dataset,
            partition=args.partitions[1],
            data_root=args.data_root,
            detection_root=args.val_detection_dir,
            box_score_thresh_h=args.human_thresh,
            box_score_thresh_o=args.object_thresh),
        collate_fn=custom_collate,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True)

    # dataset-specific object->target mapping, human class index and
    # number of target classes
    if args.dataset == 'hicodet':
        object_to_target = train_loader.dataset.dataset.object_to_verb
        human_idx = 49
        num_classes = 117
    else:  # 'vcoco' (validated above)
        object_to_target = train_loader.dataset.dataset.object_to_action
        human_idx = 1
        num_classes = 24
    net = SpatioAttentiveGraph(object_to_target,
                               human_idx,
                               num_iterations=args.num_iter,
                               postprocess=False,
                               max_human=args.max_human,
                               max_object=args.max_object)
    # Fix backbone parameters
    for p in net.backbone.parameters():
        p.requires_grad = False

    if os.path.exists(args.checkpoint_path):
        print("Continue from saved checkpoint ", args.checkpoint_path)
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
        optim_state_dict = checkpoint['optim_state_dict']
        sched_state_dict = checkpoint['scheduler_state_dict']
        epoch = checkpoint['epoch']
        iteration = checkpoint['iteration']
    else:
        print("Start from a randomly intialised model")
        optim_state_dict = None
        sched_state_dict = None
        epoch = 0
        iteration = 0

    engine = CustomisedLE(net,
                          train_loader,
                          val_loader,
                          num_classes=num_classes,
                          optim_params={
                              'lr': args.learning_rate,
                              'momentum': args.momentum,
                              'weight_decay': args.weight_decay
                          },
                          optim_state_dict=optim_state_dict,
                          lr_scheduler=True,
                          lr_sched_params={
                              'milestones': args.milestones,
                              'gamma': args.lr_decay
                          },
                          print_interval=args.print_interval,
                          cache_dir=args.cache_dir)
    # restore the training counters (zero when starting fresh)
    engine.update_state_key(epoch=epoch, iteration=iteration)
    if sched_state_dict is not None:
        engine.fetch_state_key('lr_scheduler').load_state_dict(
            sched_state_dict)

    engine(args.num_epochs)
def mode_4(sess, graph, save_path, is_valid=FLAGS.is_valid):
    """ to analyze the latent code, only changing its first dimension for comparison

    A single latent batch is drawn once. For each of the ``n_latents`` = 11
    steps ``i``, dimension ``target_dims`` of every latent is overwritten
    with -2.5 + 0.5 * i (i.e. -2.5, -2.0, ..., 2.5), and the generator is run
    on the first FLAGS.n_conditions conditions. Those conditions come from
    the valid split when ``is_valid`` is true, else from the train split.

    Saved Result
    ------------
    results_A_fake_B : float32 ndarray, shape=[n_latents=11, n_conditions=FLAGS.n_conditions, length=FLAGS.seq_length, features=23]
        Real A + Fake B
    results_A_real_B : float32 ndarray, shape=[n_conditions=FLAGS.n_conditions, length=FLAGS.seq_length, features=23]
        Real A + Real B (independent of the latent sweep)
    results_critic_scores : float32 ndarray, shape=[n_latents=11, n_conditions=FLAGS.n_conditions]
        critic scores for each generated sample

    Raises
    ------
    ValueError : if FLAGS.n_conditions is not a multiple of FLAGS.batch_size
        (the final reshape to n_conditions rows requires it)
    """
    target_dims = 0
    n_latents = 11

    # Conditions are processed in whole batches of FLAGS.batch_size; any
    # remainder would make the final reshape fail after all the generation
    # work, so reject it up front.
    if FLAGS.n_conditions % FLAGS.batch_size != 0:
        raise ValueError(
            'n_conditions ({}) must be a multiple of batch_size ({})'.format(
                FLAGS.n_conditions, FLAGS.batch_size))

    # placeholder tensor
    latent_input_t = graph.get_tensor_by_name('latent_input:0')
    team_a_t = graph.get_tensor_by_name('team_a:0')
    G_samples_t = graph.get_tensor_by_name('G_samples:0')
    matched_cond_t = graph.get_tensor_by_name('matched_cond:0')
    # result tensor
    result_t = graph.get_tensor_by_name(
        'Generator/G_inference/conv_result/conv1d/Maximum:0')
    critic_scores_t = graph.get_tensor_by_name(
        'Critic/C_inference_1/linear_result/BiasAdd:0')

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    real_data = np.load(FLAGS.data_path)[:, :FLAGS.seq_length, :, :]
    print('real_data.shape', real_data.shape)
    # normalize
    data_factory = DataFactory(real_data)
    # result collector
    results_A_fake_B = []
    results_critic_scores = []

    # shuffle the data
    train_data, valid_data = data_factory.fetch_data()
    target_data = valid_data if is_valid else train_data
    # The same latent batch is reused for every condition batch; only
    # column target_dims is overwritten in place during the sweep.
    latents = z_samples(FLAGS.batch_size)
    for idx in range(0, FLAGS.n_conditions, FLAGS.batch_size):
        real_conds = target_data['A'][idx:idx + FLAGS.batch_size]
        # generate result
        temp_critic_scores = []
        temp_A_fake_B = []
        for i in range(n_latents):
            latents[:, target_dims] = -2.5 + 0.5 * i
            result = sess.run(result_t,
                              feed_dict={latent_input_t: latents,
                                         team_a_t: real_conds})
            critic_scores = sess.run(critic_scores_t,
                                     feed_dict={G_samples_t: result,
                                                matched_cond_t: real_conds})
            temp_A_fake_B.append(data_factory.recover_data(
                np.concatenate([real_conds, result], axis=-1)))
            temp_critic_scores.append(critic_scores)
        results_A_fake_B.append(temp_A_fake_B)
        results_critic_scores.append(temp_critic_scores)
    # each entry is [n_latents, batch, ...]; concat along conditions (axis=1)
    results_A_fake_B = np.concatenate(results_A_fake_B, axis=1)
    results_critic_scores = np.concatenate(results_critic_scores, axis=1)
    results_A = data_factory.recover_BALL_and_A(
        target_data['A'][:FLAGS.n_conditions])
    results_real_B = data_factory.recover_B(
        target_data['B'][:FLAGS.n_conditions])
    results_A_real_B = np.concatenate([results_A, results_real_B], axis=-1)
    # saved as numpy
    print(np.array(results_A_fake_B).shape)
    print(np.array(results_A_real_B).shape)
    print(np.array(results_critic_scores).shape)
    np.save(save_path + 'results_A_fake_B.npy',
            np.array(results_A_fake_B).astype(np.float32).reshape(
                [n_latents, FLAGS.n_conditions, FLAGS.seq_length, 23]))
    np.save(save_path + 'results_A_real_B.npy',
            np.array(results_A_real_B).astype(np.float32).reshape(
                [FLAGS.n_conditions, FLAGS.seq_length, 23]))
    np.save(save_path + 'results_critic_scores.npy',
            np.array(results_critic_scores).astype(np.float32).reshape(
                [n_latents, FLAGS.n_conditions]))
    print('!!Completely Saved!!')