Example #1
def main(_):
    print('loading data...')
    tr_revs, tr_revs_content = read_data_file(FLAGS.train_file_path)
    word_idx_map, w2v = creat_word_embedding(tr_revs_content,
                                             FLAGS.embedding_dim)

    tr_data = data_parse(tr_revs, word_idx_map, FLAGS.max_sentence_len)

    te_revs, _ = read_data_file(FLAGS.test_file_path)
    te_data = data_parse(te_revs, word_idx_map, FLAGS.max_sentence_len)

    tc_lstm = TC_LSTM(
        n_hidden=FLAGS.n_hidden,
        n_class=FLAGS.n_class,
        max_sentence_len=FLAGS.max_sentence_len,
        l2_reg=FLAGS.l2_reg,
    )
    print('start training...')
    tc_lstm.learn(word_idx_map,
                  w2v,
                  tr_data,
                  te_data,
                  n_iter=FLAGS.n_iter,
                  batch_size=FLAGS.batch_size,
                  learning_rate=FLAGS.learning_rate)
Example #2
def load_data(config: Config, transformer: TokensToNumbers, type):
    lines = read_data_file(config.get_data_path(type))
    tokens = pre_process(
        lines, config.baseline_config.embedding_config.pre_processing)
    return transformer.transform(
        config.baseline_config.embedding_config.input_type, tokens,
        config.max_sequence_length)
Example #3
 def read_data_restart(self, x=-1, y=-1):
     # print("mydict: {}".format(mydict))
     # self.inventory = Inventory("player")
     # self.inventory.read_data()
     user_data = utils.get_user_data()
     path = os.path.join("data", user_data["character_name"], constants.PLAYER_DATA_FILE)
     mylist = utils.read_data_file(path, num_of_fields=11)
     mydict = mylist[0]
     # ----
     if x == -1 and y == -1:
         self.x = mydict["x"]
         self.y = mydict["y"]
     else:
         if x == -1 or y == -1:
             raise ValueError("Error!")
         self.x = x
         self.y = y
     # ----
     self.name = mydict["name"]
     self.kind = mydict["kind"]
     if utils.is_int(mydict["direction"]) == True:
         self.direction = -90
     else:
         self.direction = utils.convert_direction_to_integer(mydict["direction"])
     self.max_hit_points = mydict["max_hit_points"]
     self.hit_points = mydict["hit_points"]
     self.chance_to_hit = mydict["chance_to_hit"]
     self.experience = mydict["experience"]
     self.profession = mydict["profession"]
     self.gold = mydict["gold"]
     # ----
     self.load_images()
     # ----
     self.direction = "DOWN"
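For orientation, here is a minimal sketch of a reader that would satisfy the utils.read_data_file(path, num_of_fields=...) call pattern used above, which returns a list of dicts whose first element carries fields such as "x", "y" and "name". The helper name, the key: value record format, and the string-typed values are assumptions for illustration only; the game's actual implementation (and any numeric conversions it performs) may differ.

def read_data_file_sketch(filepath, num_of_fields):
    """Hypothetical stand-in for utils.read_data_file: groups 'key: value'
    lines into dicts of num_of_fields entries each."""
    records, current = [], {}
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            key, _, value = line.partition(":")
            current[key.strip()] = value.strip()
            if len(current) == num_of_fields:
                records.append(current)
                current = {}
    return records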
Example #4
    def create_file(self, row_count=None):
        """
        Creates a new large test file, passing in a row count to 
        set a maximum number of rows. 
        """
        self.ts_df, self.lar_df = utils.read_data_file(path=self.source_filepath, 
            data_file=self.source_filename)

        #Stores the second column of the TS data as "bank_name." 
        self.bank_name = self.ts_df.iloc[0][1]
        
        #Stores the fifteenth column of the TS data as "lei."    
        self.lei = self.ts_df.iloc[0][14]

        #Sets the 'lar_entries' field of the TS row to the number of rows specified.
        self.ts_df['lar_entries'] = str(row_count)

        #Changes each LAR row to the LEI specified in the TS row. 
        self.lar_df["lei"] = self.lei

        #Creates a dataframe of LAR with the number of rows specified.
        new_lar_df = self.new_lar_rows(row_count=row_count)

        #Writes file to the output filename and path. 
        utils.write_file(path=self.output_filepath, ts_input=self.ts_df, 
            lar_input=new_lar_df, name=self.output_filename)
        
        #Prints out a statement with the number of rows created, 
        #and the location of the new file.
        statement = (str("{:,}".format(row_count)) + 
            " Row File Created for " + str(self.bank_name) + 
            " File Path: " + str(self.output_filepath+self.output_filename))
        
        print(statement)
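For context, a minimal sketch of a reader compatible with the utils.read_data_file(path=..., data_file=...) pattern used in this and the following HMDA examples, which returns a one-row TS dataframe and a LAR dataframe. The pipe-delimited layout with the Transmittal Sheet as the first row is an assumption for illustration; the project's real helper presumably also applies the TS/LAR schemas so the columns are named.

import os
import pandas as pd

def read_data_file_sketch(path, data_file):
    """Hypothetical stand-in: splits a pipe-delimited submission file into a
    one-row TS dataframe and a LAR dataframe (columns left unnamed here)."""
    with open(os.path.join(path, data_file)) as f:
        rows = [line.rstrip("\n").split("|") for line in f if line.strip()]
    ts_df = pd.DataFrame(rows[:1])
    lar_df = pd.DataFrame(rows[1:])
    return ts_df, lar_df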
Example #5
 def read_data(self):
     filepath = os.path.join("data", "zones", self.zone_name, "zone_init.txt")
     mylist = utils.read_data_file(filepath, 4)
     mydict = mylist[0]
     self.zone_description = mydict["zone_description"]
     self.obstacles.read_data()
     self.walkables.read_data()
Example #6
 def read_data_first(self):
     # self.inventory.read_data()
     user_data = utils.get_user_data()
     path = os.path.join("data", user_data["character_name"], constants.PLAYER_DATA_FILE)
     # print("path: {}".format(path))
     mylist = utils.read_data_file(path, num_of_fields=11)
     mydict = mylist[0]
     # print("mydict: {}".format(mydict))
     # ----
     self.x = mydict["x"]
     self.y = mydict["y"]
     self.name = mydict["name"]
     self.kind = mydict["kind"]
     if utils.is_int(mydict["direction"]) == True:
         self.direction = int(mydict["direction"])
     else:
         self.direction = utils.convert_direction_to_integer(mydict["direction"])
     self.max_hit_points = mydict["max_hit_points"]
     self.hit_points = mydict["hit_points"]
     self.chance_to_hit = mydict["chance_to_hit"]
     self.experience = mydict["experience"]
     self.profession = mydict["profession"]
     self.gold = mydict["gold"]
     # ----
     self.load_images()
Example #7
    def __init__(self, passes_all_filepath, passes_all_filename,
                 fails_all_filepath, fails_all_filename):
        """The class instantiates with a filepath and name to a file that 
        passes every edit, and a filepath and name to a file that fails
        an edit or multiple edits for every row."""

        #To create a file that has a certain number of clean rows and failed rows,
        # two sets of TS and LAR data are loaded to create a file that
        # has some rows that fail edits, and some rows that do not.
        # The constructor provides a way for users to input their files that pass cleanly
        # and files that fail an edit or multiple edits for every row.

        #Loading TS and LAR data for the clean file.
        self.clean_ts, self.clean_lar = utils.read_data_file(
            path=passes_all_filepath, data_file=passes_all_filename)

        #Loading TS and LAR data for the file that fails a set of edits for every row.
        self.fail_ts, self.fail_lar = utils.read_data_file(
            path=fails_all_filepath, data_file=fails_all_filename)

        #Storing the first LAR row of the clean file.
        self.clean_lar = self.clean_lar.iloc[0:1]
        #print(self.clean_lar)

        #Storing the LEI in a global variable.
        self.lei = self.clean_ts.iloc[0][14]

        #Storing the bank name as a global variable.
        self.bank_name = self.clean_ts.iloc[0][1]

        #Setting the LEI for each LAR dataframe to self.lei.
        self.clean_lar["lei"] = self.lei

        self.fail_lar["lei"] = self.lei

        #Storing the first LAR row of the file that fails a set of edits for every row.
        self.fail_lar = self.fail_lar.iloc[0:1]
        #print(self.fail_lar)

        #Prints a statement indicating that the object has been instantiated.
        print("Instantiated EditFailsbyRow Class")
Example #8
def prepare_nbhd_map(dest_city_name=None, X_match_df=None):
    file_name = [i for i in LIST_CITY_DATA_FILE_NAME if dest_city_name in i][0]
    X_dest_raw = read_data_file(file_name=file_name, data_type='raw')
    # map1
    plot_df_map_nbhd = X_dest_raw[X_dest_raw[col_grain].isin(
        X_match_df[col_grain])].copy()
    plot_df_map_nbhd = plot_df_map_nbhd.drop_duplicates(subset=[col_grain])
    map_graph_nbhd = plot_nbhd_on_map(plot_df=plot_df_map_nbhd,
                                      marker_size=20,
                                      map_zoom=11,
                                      show_plot=False)
    return map_graph_nbhd
Example #9
    def edit_report(self):
        """
		This function takes in a filepath and name, producing a report on
		whether any rows of the data failed syntax, validity or quality edits.

		The report contains among its fields the edit name, the status of the 
		edit, the number of rows failed by the edit, if any, and the ULI's or 
		NULIs (loan ID) of the rows that fail the edit. 

		The resulting report is saved as a csv file using configurations
		from the test_filepaths.yaml file. 
		"""

        #Instantiates the rules engine class as a checker object with a
        #LAR schema, a TS schema, and geographic data.
        checker = rules_engine(lar_schema=self.lar_schema_df,
                               ts_schema=self.ts_schema_df,
                               geographic_data=self.geographic_data)

        #Separates data from the filepath and filename into a TS dataframe
        #and a LAR dataframe.
        ts_df, lar_df = utils.read_data_file(
            path=self.edit_report_config['data_filepath'],
            data_file=self.edit_report_config['data_filename'])

        #Loads the TS and LAR dataframes into the checker object.
        checker.load_data_frames(ts_df, lar_df)

        #Applies each function in the rules engine that checks for edits
        #and creates a results list of edits failed or passed.
        for func in dir(checker):
            if func[:1] in ("s", "v", "q") and func[1:4].isdigit() == True:
                getattr(checker, func)()

        #Creates a dataframe of results from the checker.
        report_df = pd.DataFrame(checker.results)

        #Writes the report to the filepath and name designated in
        #the test_filepaths yaml.
        edit_report_path = self.edit_report_config[
            'edit_report_output_filepath']

        if not os.path.exists(edit_report_path):
            os.makedirs(edit_report_path)

        report_df.to_csv(
            edit_report_path +
            self.edit_report_config['edit_report_output_filename'])

        #Logs the result.
        logging.info("Edit Report has been created in {filepath}".format(
            filepath=edit_report_path))
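The loop over dir(checker) relies only on a naming convention: any method whose name starts with s, v, or q followed by a three-digit number is treated as an edit check and called. A small self-contained illustration of that filter, using a hypothetical class rather than the project's rules engine:

class DemoChecker:
    """Hypothetical object used only to illustrate the naming convention."""
    def s300(self):
        print("applied s300")
    def v610(self):
        print("applied v610")
    def load_data_frames(self):
        pass  # skipped below: name does not match the s/v/q + digits pattern

checker = DemoChecker()
for func in dir(checker):
    if func[:1] in ("s", "v", "q") and func[1:4].isdigit():
        getattr(checker, func)()
# prints "applied s300" and then "applied v610"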
Example #10
def main(_):
    print('loading data...')
    tr_revs, tr_revs_content = read_data_file(FLAGS.train_file_path)
    word_idx_map, w2v = creat_word_embedding(tr_revs_content,
                                             FLAGS.embedding_dim)

    tr_x, _, tr_y = data_parse_one_direction(tr_revs, word_idx_map,
                                             FLAGS.max_sentence_len)

    te_revs, _ = read_data_file(FLAGS.test_file_path)
    te_x, _, te_y = data_parse_one_direction(te_revs, word_idx_map,
                                             FLAGS.max_sentence_len)

    cnn = Text_CNN(max_len=2 * FLAGS.max_sentence_len + 1,
                   n_classes=FLAGS.n_classes)

    print('start training...')
    cnn.learn(word_idx_map,
              w2v, (tr_x, tr_y), (te_x, te_y),
              n_iters=FLAGS.n_iters,
              batch_size=FLAGS.batch_size,
              learning_rate=FLAGS.learning_rate)
Example #11
def load_data_for_conssed(config: Config, transformer: TokensToNumbers, type):
    lines = read_data_file(config.get_data_path(type))
    data = []
    for part_config in [
            config.conssed_config.semantic_part,
            config.conssed_config.sentiment_part
    ]:
        tokens = pre_process(lines,
                             part_config.embedding_config.pre_processing)
        data.append(
            transformer.transform(part_config.embedding_config.input_type,
                                  tokens, config.max_sequence_length))

    sem_data, sen_data = data
    return sem_data, sen_data
Example #12
 def read_data(self, name_and_location):
     # filepath = os.path.join("data", constants.MONSTERS_DATA_FILE)
     # number_of_fields = 8
     # mylist = utils.read_data_file(filepath, number_of_fields)
     # mydict = mylist[0]
     filename = "{}.txt".format(name_and_location[0].strip())
     player_data = utils.get_user_data()
     filepath = os.path.join("data", player_data["character_name"], "monsters", filename)
     number_of_fields = 11
     mylist = utils.read_data_file(filepath, number_of_fields)
     mydict = mylist[0]
     # ----
     # self.x = mydict["x"]
     # self.y = mydict["y"]
     self.x = name_and_location[1]
     self.y = name_and_location[2]
     # self.name = mydict["name"]
     if mydict["name"].lower() != name_and_location[0].lower().strip():
         raise ValueError("Error! mydict[name]: {}; name_and_location[0].strip(): {}".format(mydict["name"], name_and_location[0].strip()))
     self.name = mydict["name"].lower()
     self.kind = mydict["kind"]
     self.maximum_damage = mydict["maximum_damage"]
     self.max_hit_points = mydict["max_hit_points"]
     self.hit_points = mydict["hit_points"]
     self.chance_to_hit = mydict["chance_to_hit"]
     self.experience = mydict["experience"]
     # ---------------------------------------------
     self.monster_image = mydict["monster_image"]
     self.monster_image_dead = mydict["monster_image_dead"]
     # ---------------------------------------------
     filepath = os.path.join("data", "images", self.monster_image)
     try:
         self.image = pygame.image.load(filepath).convert_alpha()
     except:
         s = "Couldn't open: {}".format(filepath)
         raise ValueError(s)
     self.image = pygame.transform.scale(self.image, (constants.TILESIZE, constants.TILESIZE))
     self.rect = self.image.get_rect()
     self.rect = self.rect.move(self.x * constants.TILESIZE, self.y * constants.TILESIZE)
     # ---------------------------------------------
     filepath = os.path.join("data", "images", self.monster_image_dead)
     print("reading filepath: {}".format(filepath))
     self.image_dead_monster = pygame.image.load(filepath).convert_alpha()
     self.image_dead_monster = pygame.transform.scale(self.image_dead_monster, (constants.TILESIZE, constants.TILESIZE))
     # ---------------------------------------------
     if self.hit_points <= 0:
         self.image = self.image_dead_monster
Example #13
    def __init__(self, source_filepath, source_filename):

        #Using the read_data_file method to initialize TS and LAR data.
        self.ts_data, self.lar_data = utils.read_data_file(
            path=source_filepath, data_file=source_filename)

        #Storing the path to the file parts directory to store
        #lar rows that fail a Macro Quality Edit.
        self.file_parts_path = "../edits_files/fileparts/"

        #Storing the path to the quality edits directory to store the
        #resulting file.
        self.quality_edits_path = "../edits_files/quality/"

        #Storing the bank name from the Transmittal Sheet to
        #be used in the file names.
        self.bank_name = self.ts_data.iloc[0][1]
Example #14
 def read_data(self):
     filepath = os.path.join("data", "zones", self.zone_name, "zone_init.txt")
     file_list = utils.read_data_file(filepath, 4)
     mydict = file_list[0]
     self.zone_description = mydict["zone_description"]
     # -------------------
     # # ---- MERCHANTS ----
     # if mydict["merchants"] == None:
     #     pass
     # elif mydict["merchants"] == "none":
     #     self.merchants = None
     # else:
     #     self.merchants = mydict["merchants"].split(";")
     #     self.merchants = [i.strip() for i in self.merchants if len(i.strip()) > 0]
     # # ----
     # big_list = []
     # for a_merchant in self.merchants:
     #     # print("a_merchant: {}".format(a_merchant))
     #     mylist = a_merchant.split(" ")
     #     mylist = (mylist[0], int(mylist[1]), int(mylist[2]))
     #     big_list.append(mylist)
     # print("big_list: {}".format(big_list))
     # # self.merchants = Npcs(["laura", "alvin"])
     # self.merchants = Npcs()
     # self.merchants.read_data(big_list)
     # ------------------
     # ---- MONSTERS ----
     # print("mydict: {}".format(mydict))
     temp_list = []
     if mydict["monsters"] == "none":
         self.monsters = None
     else:
         temp_list = mydict["monsters"].split(";")
         temp_list = [i.strip() for i in temp_list if len(i.strip()) > 0]
     # ----
     print("temp_list: {}".format(temp_list))
     big_list = []
     for a_monster in temp_list:
         print("a_monster: {}".format(a_monster))
         mylist = a_monster.split(" ")
         mylist = (mylist[0], int(mylist[1]), int(mylist[2]))
         big_list.append(mylist)
     print("big_list MONSTERS: {}".format(big_list))
     # self.merchants = Npcs(["laura", "alvin"])
     self.monsters = Monsters()
     self.monsters.read_data(big_list)
Example #15
    def __init__(
            self,
            dbg_mode=False,
            dbg_size=32,
            dbg_image_label_dict='./output/classifier/shapes-redcolor/explainer_input/list_attr_3_5000.txt',
            dbg_img_indices=[]):
        self.input_size = 64
        shapes_dir = os.path.join('data', 'shapes')
        self.dbg_mode = dbg_mode
        dataset = h5py.File(os.path.join(shapes_dir, '3dshapes.h5'), 'r')
        if self.dbg_mode:
            print(
                'Debug mode activated. #{} samples from the shapes dataset will be considered.'
                .format(dbg_size))
            if len(dbg_img_indices) == 0:
                _, file_names_dict = read_data_file(dbg_image_label_dict)
                _tmp_list = list(file_names_dict.keys())[:dbg_size]
            else:
                _tmp_list = dbg_img_indices[:dbg_size]
            self.tmp_list = list(np.sort([int(ind) for ind in _tmp_list]))
            self.images = np.array(dataset['images'][self.tmp_list])
        else:
            self.images = np.array(
                dataset['images']
            )  # array shape [480000, 64, 64, 3], uint8 in range(256)
        self.images = self.images / 255.0
        self.images = self.images - 0.5
        self.images = self.images * 2.0
        self.attributes = np.array(dataset['labels'])
        self._image_shape = self.images.shape[1:]  # [64, 64, 3]
        self._label_shape = self.attributes.shape[1:]  # [6]
        self._n_samples = self.attributes.shape[
            0]  # 10 * 10 * 10 * 8 * 4 * 15 = 480000

        self._FACTORS_IN_ORDER = [
            'floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape',
            'orientation'
        ]
        self._NUM_VALUES_PER_FACTOR = {
            'floor_hue': 10,
            'wall_hue': 10,
            'object_hue': 10,
            'scale': 8,
            'shape': 4,
            'orientation': 15
        }
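The three lines that rescale self.images map uint8 pixel values from [0, 255] into [-1, 1]. A quick standalone check of that arithmetic:

import numpy as np

pixels = np.array([0.0, 127.5, 255.0])
scaled = (pixels / 255.0 - 0.5) * 2.0
print(scaled)  # [-1.  0.  1.]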
Example #16
def validate_state_codes(path, lar_file):
    """Parses through an existing test file and replaces the state code 
	abbreviation with one that maps to the FIPS state code indicated 
	in the census tract field."""

    #Separates LAR and TS data.
    ts_data, lar_data = utils.read_data_file(path=path, data_file=lar_file)

    print(len(lar_data.columns))

    if len(lar_data.columns) != 110:
        print("Not the right number of columns")
        print("Number of columns is " + str(len(lar_data.columns)))

    #Loads the geographic file configuration.
    with open('configurations/geographic_data.yaml') as f:
        # Uses safe_load instead of load.
        geographic = yaml.safe_load(f)

    #Iterates through each row in the LAR data.
    for index, row in lar_data.iterrows():
        print(row.state)
        if row.state != 'NA':
            #Assigns the state code abbreviation key from the FIPS state code value to the row's
            #state code field.
            print("The census tract is " + row['tract'])
            print("Original value is " + row['state'])
            print("New state is " +
                  geographic['state_FIPS_to_abbreviation'][row['tract'][0:2]])
            new_state = geographic['state_FIPS_to_abbreviation'][
                row['tract'][0:2]]
            #Assignments to the row copy returned by iterrows() do not persist
            #in lar_data, so the DataFrame is updated directly.
            lar_data.at[index, 'state'] = new_state
            print("Changed value is " + new_state)

    #Prints a statement when the process is complete.
    print("Validating State Code Abbreviations")

    #Writes file back to the original path.
    utils.write_file(path=path,
                     ts_input=ts_data,
                     lar_input=lar_data,
                     name=lar_file)

    #Prints a statement when the file is re-written.
    print("File rewritten to " + path)
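The lookup geographic['state_FIPS_to_abbreviation'][row['tract'][0:2]] assumes the YAML file holds a FIPS-to-abbreviation map keyed by the two-digit state code that begins each census tract number. A sketch of the structure yaml.safe_load would need to return, with a few hypothetical entries:

# Hypothetical excerpt; the real configurations/geographic_data.yaml covers all states.
geographic = {
    'state_FIPS_to_abbreviation': {
        '01': 'AL',
        '06': 'CA',
        '48': 'TX',
    }
}

tract = '06037264000'  # example census tract; the first two digits are the FIPS state code
print(geographic['state_FIPS_to_abbreviation'][tract[0:2]])  # CA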
Example #17
def visualize_train_trajectory(meta_info, data_path):
    _id, floor, fn = Path(data_path).parts[-3:]

    train_floor_data = utils.read_data_file(data_path)
    data_folder = Path(data_path).parents[3]
    floor_folder = data_folder / f"metadata/{_id}/{floor}"
    floor_info_path = floor_folder / "floor_info.json"
    floor_image_path = floor_folder / "floor_image.png"
    with open(floor_info_path) as f:
        train_floor_info = json.load(f)

    fig = utils.visualize_trajectory(
        train_floor_data.waypoint[:, 1:3],
        floor_image_path,
        train_floor_info["map_info"]["width"],
        train_floor_info["map_info"]["height"],
    )

    plot_trajectory_folder = data_folder / "train_trajectories"
    Path(plot_trajectory_folder).mkdir(parents=True, exist_ok=True)
    save_path = plot_trajectory_folder / (fn[:-4] + ".html")
    fig.write_html(str(save_path))
Example #18
#Loads and stores requisite variables.
bank_name = data_map['bank_name']
lei = data_map['lei']
tax_id = data_map['tax_id']
rows_failed = data_map['rows_failed']
rows_total = data_map['rows_total']
passes_all_filepath = data_map['passes_all_filepath']
passes_all_filename = data_map['passes_all_filename']
fails_all_filepath = data_map['fails_all_filepath']
fails_all_filename = data_map['fails_all_filename']
output_filepath = data_map['output_filepath']
output_filename = data_map['output_filename']

#Loading TS and LAR data for the clean file.
passes_ts, passes_lar = utils.read_data_file(path=passes_all_filepath,
                                             data_file=passes_all_filename)

#Loading TS and LAR data for the file that fails a set of edits for every row.
fail_ts, fail_lar = utils.read_data_file(path=fails_all_filepath,
                                         data_file=fails_all_filename)

#Storing the first LAR row of the clean file.
passes_lar_row = passes_lar.iloc[0:1]

#Storing the first LAR row of the file that fails a set of edits for every row.
fails_lar_row = fail_lar.iloc[0:1]

#Checks if the number of rows failed is less than the number of rows in total.
if rows_failed >= rows_total:
    print("Sorry - rows_failed must be a number less than rows_total.")
Example #19
def test(config):
    # ============= Experiment Folder=============
    output_dir = os.path.join(config['log_dir'], config['name'])
    classifier_output_path = os.path.join(output_dir, 'classifier_output')
    try:
        os.makedirs(classifier_output_path)
    except:
        pass
    past_checkpoint = output_dir
    # ============= Experiment Parameters =============
    BATCH_SIZE = config['batch_size']
    channels = config['num_channel']
    input_size = config['input_size']
    N_CLASSES = config['num_class']
    dataset = config['dataset']
    # In certain circumstances, for example when the classifier has been trained
    # on re-sampled data, we still want to use the whole dataset for the generative model.
    # That's why we produce the classifier's output on the test_image_label_dict.
    if ('export_image_label_dict'
            in config.keys()) and ('export_train'
                                   in config.keys()) and ('export_test'
                                                          in config.keys()):
        image_label_dict = config['export_image_label_dict']
        train_ids = config['export_train']
        test_ids = config['export_test']
    else:
        image_label_dict = config['image_label_dict']
        train_ids = config['train']
        test_ids = config['test']
    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader()
    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        my_data_loader = ShapesLoader()
    elif dataset == 'CelebA64' or dataset == 'dermatology':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
    elif dataset == 'synthderm':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(image_label_dict)
    except:
        print("Problem in reading input data file : ", image_label_dict)
        sys.exit()
    data_train = np.load(train_ids)
    data_test = np.load(test_ids)
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data_train.shape[0])
    print('The size of the testing set: ', data_test.shape[0])

    # ============= placeholder =============
    with tf.name_scope('input'):
        x_ = tf.placeholder(tf.float32,
                            [None, input_size, input_size, channels],
                            name='x-input')
        y_ = tf.placeholder(tf.int64, [None, N_CLASSES], name='y-input')
        isTrain = tf.placeholder(tf.bool)
    # ============= Model =============

    if N_CLASSES == 1:
        y = tf.reshape(y_, [-1])
        y = tf.one_hot(y, 2, on_value=1.0, off_value=0.0, axis=-1)
        logit, prediction = pretrained_classifier(x_,
                                                  n_label=2,
                                                  reuse=False,
                                                  name='classifier',
                                                  isTrain=isTrain)
    else:
        logit, prediction = pretrained_classifier(x_,
                                                  n_label=N_CLASSES,
                                                  reuse=False,
                                                  name='classifier',
                                                  isTrain=isTrain)
        y = y_
    classif_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=y,
                                                   logits=logit)
    loss = tf.losses.get_total_loss()
    # ============= Variables =============
    # Note that this list of variables only include the weights and biases in the model.
    lst_vars = []
    for v in tf.global_variables():
        lst_vars.append(v)
    # ============= Session =============
    sess = tf.InteractiveSession()
    saver = tf.train.Saver(var_list=lst_vars)
    tf.global_variables_initializer().run()
    # ============= Load Checkpoint =============
    if past_checkpoint is not None:
        ckpt = tf.train.get_checkpoint_state(past_checkpoint + '/')
        if ckpt and ckpt.model_checkpoint_path:
            print(str(ckpt.model_checkpoint_path))
            saver.restore(sess,
                          tf.train.latest_checkpoint(past_checkpoint + '/'))
        else:
            sys.exit()
    else:
        sys.exit()
    # ============= Testing - Save the Output =============

    def get_predictions(data, subset_name):
        names = np.empty([0])
        prediction_y = np.empty([0])
        true_y = np.empty([0])

        num_batch = int(data.shape[0] / BATCH_SIZE)
        for i in range(0, num_batch):
            start = i * BATCH_SIZE
            ns = data[start:start + BATCH_SIZE]
            xs, ys = my_data_loader.load_images_and_labels(
                ns,
                image_dir=config['image_dir'],
                n_class=N_CLASSES,
                file_names_dict=file_names_dict,
                num_channel=channels,
                do_center_crop=True)
            [_pred] = sess.run([prediction],
                               feed_dict={
                                   x_: xs,
                                   isTrain: False,
                                   y_: ys
                               })
            if i == 0:
                names = np.asarray(ns)
                prediction_y = np.asarray(_pred)
                true_y = np.asarray(ys)
            else:
                names = np.append(names, np.asarray(ns), axis=0)
                prediction_y = np.append(prediction_y,
                                         np.asarray(_pred),
                                         axis=0)
                true_y = np.append(true_y, np.asarray(ys), axis=0)
        np.save(classifier_output_path + '/name_{}1.npy'.format(subset_name),
                names)
        np.save(
            classifier_output_path +
            '/prediction_y_{}1.npy'.format(subset_name), prediction_y)
        np.save(classifier_output_path + '/true_y_{}1.npy'.format(subset_name),
                true_y)
        return names, prediction_y, np.reshape(true_y, [-1, N_CLASSES])

    train_names, train_prediction_y, train_true_y = get_predictions(
        data_train, 'train')
    test_names, test_prediction_y, test_true_y = get_predictions(
        data_test, 'test')

    return train_names, train_prediction_y, train_true_y, test_names, test_prediction_y, test_true_y
Example #20
    parser.add_argument("--iou-threshold", type=lambda s: [float(item) for item in s.split(',')],
                        default=[0.3, 0.05, 1.0, 0.05], help='IOU threshold for tracking')
    parser.add_argument("--segment-id", type=str, default='10289507859301986274_4200_000_4220_000',
                        help='segment id to track')
    parser.add_argument("--camera-id", type=str, default='FRONT',
                        help='camera id to track')
    parser.add_argument("--replay", action='store_true', help='replay an already tracked submission')
    parser.add_argument("--write-video", action='store_true', help='create a video from visualization')
    args = parser.parse_args()
    print(args)

    if args.replay:
        # don't apply score threshold for replaying
        args.score_threshold = [0 for _ in args.score_threshold]

    predictions = read_data_file(args.input, args.score_threshold)
    image_id2path = {}
    ground_truth_dir = dirname(args.ground_truth)
    for image in json.load(open(args.ground_truth)):
        image_id2path[image['id']] = join(ground_truth_dir, image['file_name'])

    segment_id, camera_id = args.segment_id, args.camera_id

    tracked_predictions_map = {}
    if args.replay:
        tracked_predictions_map = predictions[segment_id][camera_id]
    else:
        tracked_predictions = track_sort(predictions, segment_id, camera_id,
                                         args.iou_threshold, args.max_age, args.min_hits)
        for e in tracked_predictions:
            frame_id = int(e['image_id'].split('/')[1])
Example #21
        help="Define whether or not trials must be balanced.")
    options_parser.add_argument("--verbose",
                                "-v",
                                default=False,
                                action="store_true",
                                help="Display trials.")

    options = vars(options_parser.parse_args())
    output_dir = options["output-directory"]
    sep = options["separator"]
    balancing = options["balance"]
    verbose = options["verbose"]

    english_lst = read_lst_file(options["english"])
    french_lst = read_lst_file(options["french"])
    features = read_data_file(options["features"])

    meta = None
    if options["meta_data"]:
        meta = read_meta_file(options["meta_data"],
                              corpus="masseffect" if sep == "," else "skyrim")

    # trials = init_trials(english_lst, french_lst, meta, balancing, True, sep=sep)
    maker = TrialsMaker(english_lst,
                        french_lst,
                        meta=meta,
                        balancing=balancing,
                        shuffle=True,
                        sep=sep)
    trials = maker.make()
Example #22
#are stored in the custom_file_specifications yaml file under
#"row_by_row_file."

#The following code imports the utils package and yaml library.
import utils
import yaml

#Import custom_file_specifications yaml file.
yaml_file = 'custom_file_specifications.yaml'
with open(yaml_file, 'r') as f:
    custom = yaml.safe_load(f)

#Loads TS and LAR data from the source file and path from
#custom_file_specifications yaml file.
ts_df, lar_df = utils.read_data_file(
    path=custom["row_by_row_file"]["source_filepath"],
    data_file=custom["row_by_row_file"]["source_filename"])

#Creates a file with a number of LAR rows specified in the
#custom_file_specifications yaml file.
lar_df, ts_df = utils.new_lar_rows(
    row_count=custom["row_by_row_file"]["row_count"],
    lar_df=lar_df,
    ts_df=ts_df)

#Modifies the LAR rows based on the modification yaml file specified.
lar_df = utils.row_by_row_modification(
    lar_df, yaml_filepath=custom["row_by_row_file"]["yaml_file"])

#Writes a file to the filename and path in the custom_file_specifications yaml
#file.
Example #23
def test():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        '-c',
                        default='configs/celebA_YSBBB_Classifier.yaml')
    args = parser.parse_args()
    # ============= Load config =============
    config_path = args.config
    config = yaml.load(open(config_path))
    print(config)
    # ============= Experiment Folder=============
    output_dir = os.path.join(config['log_dir'], config['name'])
    classifier_output_path = os.path.join(output_dir, 'classifier_output')
    try:
        os.makedirs(classifier_output_path)
    except:
        pass
    past_checkpoint = output_dir
    # ============= Experiment Parameters =============
    BATCH_SIZE = config['batch_size']
    channels = config['num_channel']
    input_size = config['input_size']
    N_CLASSES = config['num_class']
    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(
            config['image_label_dict'])
    except:
        print("Problem in reading input data file : ",
              config['image_label_dict'])
        sys.exit()
    data_train = np.load(config['train'])
    data_test = np.load(config['test'])
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data_train.shape[0])
    print('The size of the testing set: ', data_test.shape[0])

    # ============= placeholder =============
    with tf.name_scope('input'):
        x_ = tf.placeholder(tf.float32,
                            [None, input_size, input_size, channels],
                            name='x-input')
        y_ = tf.placeholder(tf.int64, [None, N_CLASSES], name='y-input')
        isTrain = tf.placeholder(tf.bool)
    # ============= Model =============

    if N_CLASSES == 1:
        y = tf.reshape(y_, [-1])
        y = tf.one_hot(y, 2, on_value=1.0, off_value=0.0, axis=-1)
        logit, prediction = pretrained_classifier(x_,
                                                  n_label=2,
                                                  reuse=False,
                                                  name='classifier',
                                                  isTrain=isTrain)
    else:
        logit, prediction = pretrained_classifier(x_,
                                                  n_label=N_CLASSES,
                                                  reuse=False,
                                                  name='classifier',
                                                  isTrain=isTrain)
        y = y_
    classif_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=y,
                                                   logits=logit)
    loss = tf.losses.get_total_loss()
    # ============= Variables =============
    # Note that this list of variables only include the weights and biases in the model.
    lst_vars = []
    for v in tf.global_variables():
        lst_vars.append(v)
    # ============= Session =============
    sess = tf.InteractiveSession()
    saver = tf.train.Saver(var_list=lst_vars)
    tf.global_variables_initializer().run()
    # ============= Load Checkpoint =============
    if past_checkpoint is not None:
        ckpt = tf.train.get_checkpoint_state(past_checkpoint + '/')
        if ckpt and ckpt.model_checkpoint_path:
            print("HERE...................lod checkpoint.........")
            print(str(ckpt.model_checkpoint_path))
            saver.restore(sess,
                          tf.train.latest_checkpoint(past_checkpoint + '/'))
        else:
            sys.exit()
    else:
        sys.exit()
    # ============= Testing Save the Output =============
    names = np.empty([0])
    prediction_y = np.empty([0])
    true_y = np.empty([0])
    for epoch in range(1):
        num_batch = int(data_train.shape[0] / BATCH_SIZE)
        for i in range(0, num_batch):
            start = i * BATCH_SIZE
            ns = data_train[start:start + BATCH_SIZE]
            xs, ys = load_images_and_labels(ns,
                                            config['image_dir'],
                                            N_CLASSES,
                                            file_names_dict,
                                            input_size,
                                            channels,
                                            do_center_crop=True)
            [_pred] = sess.run([prediction],
                               feed_dict={
                                   x_: xs,
                                   isTrain: False,
                                   y_: ys
                               })
            if i == 0:
                names = np.asarray(ns)
                prediction_y = np.asarray(_pred)
                true_y = np.asarray(ys)
            else:
                names = np.append(names, np.asarray(ns), axis=0)
                prediction_y = np.append(prediction_y,
                                         np.asarray(_pred),
                                         axis=0)
                true_y = np.append(true_y, np.asarray(ys), axis=0)
        np.save(classifier_output_path + '/name_train1.npy', names)
        np.save(classifier_output_path + '/prediction_y_train1.npy',
                prediction_y)
        np.save(classifier_output_path + '/true_y_train1.npy', true_y)

        names = np.empty([0])
        prediction_y = np.empty([0])
        true_y = np.empty([0])
        num_batch = int(data_test.shape[0] / BATCH_SIZE)
        for i in range(0, num_batch):
            start = i * BATCH_SIZE
            ns = data_test[start:start + BATCH_SIZE]
            xs, ys = load_images_and_labels(ns,
                                            config['image_dir'],
                                            N_CLASSES,
                                            file_names_dict,
                                            input_size,
                                            channels,
                                            do_center_crop=True)
            [_pred] = sess.run([prediction],
                               feed_dict={
                                   x_: xs,
                                   isTrain: False,
                                   y_: ys
                               })
            if i == 0:
                names = np.asarray(ns)
                prediction_y = np.asarray(_pred)
                true_y = np.asarray(ys)
            else:
                names = np.append(names, np.asarray(ns), axis=0)
                prediction_y = np.append(prediction_y,
                                         np.asarray(_pred),
                                         axis=0)
                true_y = np.append(true_y, np.asarray(ys), axis=0)
        np.save(classifier_output_path + '/name_test1.npy', names)
        np.save(classifier_output_path + '/prediction_y_test1.npy',
                prediction_y)
        np.save(classifier_output_path + '/true_y_test1.npy', true_y)
Example #24
def load_labels(config, type):
    lines = read_data_file(config.get_data_path(type))
    labels = [emotion2label[line[4]] for line in lines]
    return to_categorical(np.asarray(labels))
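As a point of reference, a minimal sketch of the label encoding performed above, assuming a four-class emotion2label mapping; the mapping and the to_categorical import live elsewhere in the project, so both are hypothetical here.

import numpy as np

# Hypothetical mapping; the project defines the real emotion2label elsewhere.
emotion2label = {"others": 0, "happy": 1, "sad": 2, "angry": 3}

labels = np.asarray([emotion2label[e] for e in ["happy", "sad", "others"]])
# Equivalent to to_categorical(labels) for these values: one row per example,
# with a 1.0 in the column of that example's label index.
one_hot = np.eye(labels.max() + 1)[labels]
print(one_hot)
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [1. 0. 0.]]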
Example #25
# # Config

# In[17]:

writer = SummaryWriter('/pylon5/ac5616p/debdas/Explanation/tensorboard')
NUMS_CLASS = 2
BATCH_SIZE = 8
input_size = 128
EPOCHS = 10
channels = 3
lambda_GAN = 1
lambda_cyc = 100
image_dir = '/pghbio/dbmi/batmanlab/singla/MICCAI_2019/GAN_Interpretability/data/celebA/images/'
image_label_dict = '/pylon5/ac5616p/debdas/Explanation/data/CelebA/Young_binary_classification.txt'  # Smiling_binary_classification.txt
try:
    categories, file_names_dict = read_data_file(image_label_dict, image_dir)
except:
    print("Problem in reading input data file : ", image_label_dict)
    sys.exit()
data = np.asarray(list(file_names_dict.keys()))

# CUDA setting
if not torch.cuda.is_available():
    raise ValueError("Should buy GPU!")
torch.manual_seed(46)
torch.cuda.manual_seed_all(46)
device = torch.device('cuda')
torch.set_default_tensor_type('torch.cuda.FloatTensor')
torch.backends.cudnn.benchmark = True

# dataset
Example #26
#load config data from yaml file
with open('configurations/clean_file_config.yaml') as f:
    # use safe_load instead of load
    data_map = yaml.safe_load(f)

with open('configurations/test_filepaths.yaml') as f:
    filepaths = yaml.safe_load(f)

ts_schema = pd.DataFrame(json.load(open(
    filepaths['ts_schema_json'])))  #load TS schema
lar_schema = pd.DataFrame(json.load(open(
    filepaths['lar_schema_json'])))  #load LAR schema
#instantiate test_file_generator.py to modify clean data so that the resulting files fail specific edits
file_maker = test_data(ts_schema=ts_schema,
                       lar_schema=lar_schema)  #instantiate edit file maker

ts_data, lar_data = utils.read_data_file(
    path=filepaths['clean_filepath'].format(
        bank_name=data_map["name"]["value"]),
    data_file=data_map["clean_file"]["value"])  #load clean data file
file_maker.load_data_frames(
    ts_data, lar_data)  #pass clean file data to file maker object
#generate a file for each edit function in file maker
edits = []
for func in dir(file_maker):  #loop over all data modification functions
    if func[:1] in ("s", "v", "q") and func[1:4].isdigit(
    ) == True:  #check if function is a numbered syntax or validity edit
        print("applying:", func)
        getattr(file_maker,
                func)()  #apply data modification functions and produce files
Example #27
# run pagerank
print("Ranking\t3\tExecuting Ranking Algorithm", flush=True)
PR = pagerank_power(G, p=alpha, tol=tol)

# sorting output
print("Ranking\t4\tSorting Results", flush=True)
sorted_indices = np.argsort(PR)[::-1][:len(PR)]

# filter out indices that are not in A
indices = np.unique(np.concatenate((A[:, 0], A[:, 1])))
sorted_indices = sorted_indices[np.isin(sorted_indices, indices)]

# write results to output file
print("Ranking\t5\tWriting Results", flush=True)

if src_field != "null":
    idx_file = nodes_dir + metapath[0] + ".csv"
    idx = utils.read_data_file(idx_file, src_field)

with open(outfile, 'w', newline='') as csvfile:
    filewriter = csv.writer(csvfile,
                            delimiter='\t',
                            quotechar='"',
                            quoting=csv.QUOTE_MINIMAL)
    for i in sorted_indices:
        if src_field != "null":
            filewriter.writerow([str(i), idx[str(i)], PR[i]])
        else:
            filewriter.writerow([str(i), PR[i]])
Example #28
def get_source_city_info(source_city_name):
    file_name = [i for i in LIST_CITY_DATA_FILE_NAME
                 if source_city_name in i][0]
    X_source = read_data_file(file_name=file_name, data_type='artifact_app')
    list_sources_venues = list(X_source[col_grain].values)
    return X_source, list_sources_venues
Example #29
def read_dest_city_data(dest_city_name):
    file_name = [i for i in LIST_CITY_DATA_FILE_NAME if dest_city_name in i][0]
    X_dest = read_data_file(file_name=file_name, data_type='artifact_app')
    return X_dest
Example #30
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str)
    parser.add_argument('--debug', '-d', action='store_true')
    args = parser.parse_args()

    # ============= Load config =============
    config_path = args.config
    config = yaml.load(open(config_path))
    print(config)

    # ============= Experiment Folder=============
    assets_dir = os.path.join(config['log_dir'], config['name'])
    log_dir = os.path.join(assets_dir, 'log')
    ckpt_dir = os.path.join(assets_dir, 'ckpt_dir')
    sample_dir = os.path.join(assets_dir, 'sample')
    test_dir = os.path.join(assets_dir, 'test')
    # make directory if not exist
    try:
        os.makedirs(log_dir)
    except:
        pass
    try:
        os.makedirs(ckpt_dir)
    except:
        pass
    try:
        os.makedirs(sample_dir)
    except:
        pass
    try:
        os.makedirs(test_dir)
    except:
        pass

    # ============= Experiment Parameters =============
    ckpt_dir_cls = config['cls_experiment']
    BATCH_SIZE = config['batch_size']
    EPOCHS = config['epochs']
    channels = config['num_channel']
    input_size = config['input_size']
    NUMS_CLASS_cls = config['num_class']
    NUMS_CLASS = config['num_bins']
    MU_CLUSTER = config['mu_cluster']
    VAR_CLUSTER = config['var_cluster']
    TRAVERSAL_N_SIGMA = config['traversal_n_sigma']
    STEP_SIZE = 2*TRAVERSAL_N_SIGMA * VAR_CLUSTER/(NUMS_CLASS - 1)
    OFFSET = MU_CLUSTER - TRAVERSAL_N_SIGMA*VAR_CLUSTER
    target_class = config['target_class']

    # CSVAE parameters
    beta1 = config['beta1']
    beta2 = config['beta2']
    beta3 = config['beta3']
    beta4 = config['beta4']
    beta5 = config['beta5']
    z_dim = config['z_dim']
    w_dim = config['w_dim']

    save_summary = int(config['save_summary'])
    save_ckpt = int(config['save_ckpt'])
    ckpt_dir_continue = config['ckpt_dir_continue']

    dataset = config['dataset']
    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader()
        EncoderZ = EncoderZ_128
        EncoderW = EncoderW_128
        DecoderX = DecoderX_128
        DecoderY = DecoderY_128

    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        if args.debug:
            my_data_loader = ShapesLoader(dbg_mode=True, dbg_size=config['batch_size'],
                                          dbg_image_label_dict=config['image_label_dict'])
        else:
            my_data_loader = ShapesLoader()
        EncoderZ = EncoderZ_64
        EncoderW = EncoderW_64
        DecoderX = DecoderX_64
        DecoderY = DecoderY_64

    elif dataset == 'CelebA64' or dataset == 'dermatology':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
        EncoderZ = EncoderZ_64
        EncoderW = EncoderW_64
        DecoderX = DecoderX_64
        DecoderY = DecoderY_64
    elif dataset == 'synthderm':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
        EncoderZ = EncoderZ_64
        EncoderW = EncoderW_64
        DecoderX = DecoderX_64
        DecoderY = DecoderY_64

    if ckpt_dir_continue == '':
        continue_train = False
    else:
        ckpt_dir_continue = os.path.join(ckpt_dir_continue, 'ckpt_dir')
        continue_train = True

    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(config['image_label_dict'])
    except:
        print("Problem in reading input data file : ", config['image_label_dict'])
        sys.exit()
    data = np.asarray(list(file_names_dict.keys()))

    # CSVAE does not need discretizing categories. The default 2 is recommended.
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data.shape[0])
    fp = open(os.path.join(log_dir, 'setting.txt'), 'w')
    fp.write('config_file:' + str(config_path) + '\n')
    fp.close()

    # ============= placeholder =============
    x_source = tf.placeholder(tf.float32, [None, input_size, input_size, channels], name='x_source')
    y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS_cls], name='y_s')
    y_source = y_s[:, NUMS_CLASS_cls-1]
    train_phase = tf.placeholder(tf.bool, name='train_phase')

    y_target = tf.placeholder(tf.int32, [None, w_dim], name='y_target')  # between 0 and NUMS_CLASS

    # ============= CSVAE =============

    encoder_z = EncoderZ('encoder_z')
    encoder_w = EncoderW('encoder_w')
    decoder_x = DecoderX('decoder_x')
    decoder_y = DecoderY('decoder_y')

    # encode x to get mean, log variance, and samples from the latent subspace Z
    mu_z, logvar_z, z = encoder_z(x_source, z_dim)
    # encode x and y to get mean, log variance, and samples from the latent subspace W
    mu_w, logvar_w, w = encoder_w(x_source, y_source, w_dim)

    # pass samples of z and w to get predictions of x
    pred_x = decoder_x(tf.concat([w, z], axis=-1))
    # get predicted labels based only on the latent subspace Z
    pred_y = decoder_y(z, NUMS_CLASS_cls)

    # Create and save a grid of images
    fake_img_traversal = tf.zeros([0, input_size, input_size, channels])
    for i in range(w_dim):
        for j in range(NUMS_CLASS):
            val = j * STEP_SIZE
            np_arr = np.zeros((BATCH_SIZE, w_dim))
            np_arr[:, i] = val
            tmp_w = tf.convert_to_tensor(np_arr, dtype=tf.float32)
            fake_img = decoder_x(tf.concat([tmp_w, z], axis=-1))
            fake_img_traversal = tf.concat([fake_img_traversal, fake_img], axis=0)
    fake_img_traversal_board = make4d_tensor(fake_img_traversal, channels, input_size, w_dim, NUMS_CLASS, BATCH_SIZE)
    fake_img_traversal_save = make3d_tensor(fake_img_traversal, channels, input_size, w_dim, NUMS_CLASS, BATCH_SIZE)

    # Create and save 2d traversal, this is relevant only for w_dim == 2
    fake_2d_img_traversal = tf.zeros([0, input_size, input_size, channels])
    for i in range(NUMS_CLASS):
        for j in range(NUMS_CLASS):
            val_0 = i * STEP_SIZE
            val_1 = j * STEP_SIZE
            np_arr = np.zeros((BATCH_SIZE, w_dim))
            np_arr[:, 0] = val_0
            np_arr[:, 1] = val_1
            tmp_w = tf.convert_to_tensor(np_arr, dtype=tf.float32)
            fake_2d_img = decoder_x(tf.concat([tmp_w, z], axis=-1))
            fake_2d_img_traversal = tf.concat([fake_2d_img_traversal, fake_2d_img], axis=0)
    fake_2d_img_traversal_board = make4d_tensor(fake_2d_img_traversal, channels, input_size, NUMS_CLASS, NUMS_CLASS, BATCH_SIZE)
    fake_2d_img_traversal_save = make3d_tensor(fake_2d_img_traversal, channels, input_size, NUMS_CLASS, NUMS_CLASS, BATCH_SIZE)

    # Create a single image based on y_target
    target_w = STEP_SIZE * tf.cast(y_target, dtype=tf.float32) + OFFSET
    fake_target_img = decoder_x(tf.concat([target_w, z], axis=-1))

    # ============= pre-trained classifier =============

    real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier(x_source, NUMS_CLASS_cls,
                                                                                   reuse=False, name='classifier')
    fake_recon_cls_logit_pretrained, fake_recon_cls_prediction = pretrained_classifier(pred_x, NUMS_CLASS_cls,
                                                                                       reuse=True)
    fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier(fake_img, NUMS_CLASS_cls,
                                                                                   reuse=True)

    # ============= predicted probabilities =============
    fake_target_p_tensor = tf.reduce_max(tf.cast(y_target, tf.float32) * 1.0 / float(NUMS_CLASS - 1), axis=1)

    # ============= Loss =============
    # OPTIMIZATION:

    # Specified in section 4.1 of http://www.cs.toronto.edu/~zemel/documents/Conditional_Subspace_VAE_all.pdf
    # There are three components: M1, M2, N

    # 1.Optimize the first loss related to maximizing variational lower bound
    #   on the marginal log likelihood and minimizing mutual information

    # define two KL divergences:
    # KL divergence for label 1
    #    We want the latent subspace W for this label to be close to mean 0, var 0.01
    kl1 = KL(mu1=mu_w, logvar1=logvar_w,
             mu2=tf.zeros_like(mu_w), logvar2=tf.ones_like(logvar_w) * np.log(0.01))
    # KL divergence for label 0
    #    We want the latent subspace W for this label to be close to mean MU_CLUSTER, var VAR_CLUSTER
    kl0 = KL(mu1=mu_w, logvar1=logvar_w, mu2=tf.ones_like(mu_w) * MU_CLUSTER, logvar2=tf.ones_like(logvar_w) * np.log(VAR_CLUSTER))
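    # For reference, and as an assumption about the KL() helper (its definition is
    # not part of this excerpt): for diagonal Gaussians it presumably computes, per
    # latent dimension,
    #   0.5 * (logvar2 - logvar1 + (exp(logvar1) + (mu1 - mu2)**2) / exp(logvar2) - 1)
    # summed or averaged over the latent dimensions.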

    loss_m1_1 = tf.reduce_sum(beta1 * tf.reduce_sum((x_source - pred_x) ** 2, axis=-1))  # corresponds to M1
    loss_m1_2 = tf.reduce_sum(
        beta2 * tf.where(tf.equal(y_source, tf.ones_like(y_source)), kl1, kl0))  # corresponds to M1
    loss_m1_3 = tf.reduce_sum(
        beta3 * KL(mu_z, logvar_z, tf.zeros_like(mu_z), tf.zeros_like(logvar_z)))  # corresponds to M1
    loss_m2 = tf.reduce_sum(beta4 * tf.reduce_sum(pred_y * safe_log(pred_y), axis=-1))  # corresponds to M2

    loss_m1 = loss_m1_1 + loss_m1_2 + loss_m1_3
    loss1 = loss_m1 + loss_m2

    # 2. Optimize second loss related to learning the approximate posterior

    loss_n = tf.reduce_sum(beta5 * tf.where(y_source == 1, -safe_log(pred_y[:, 1]), -safe_log(pred_y[:, 0])))  # N

    loss2 = loss_n

    optimizer_1 = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(loss1, var_list=decoder_x.var_list() +
                                                                                             encoder_w.var_list() +
                                                                                             encoder_z.var_list(),
                                                                             global_step=global_step)
    optimizer_2 = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(loss2, var_list=decoder_y.var_list(),
                                                                             global_step=global_step)

    # combine losses for tracking
    loss = loss1 + loss2

    # ============= summary =============
    real_img_sum = tf.summary.image('real_img', x_source)
    fake_recon_img_sum = tf.summary.image('fake_recon_img', pred_x)
    fake_img_sum = tf.summary.image('fake_target_img', fake_target_img)
    fake_img_traversal_sum = tf.summary.image('fake_img_traversal', fake_img_traversal_board)
    fake_2d_img_traversal_sum = tf.summary.image('fake_2d_img_traversal', fake_2d_img_traversal_board)

    loss_m1_sum = tf.summary.scalar('losses/M1', loss_m1)
    loss_m1_1_sum = tf.summary.scalar('losses/M1/m1_1', loss_m1_1)
    loss_m1_2_sum = tf.summary.scalar('losses/M1/m1_2', loss_m1_2)
    loss_m1_3_sum = tf.summary.scalar('losses/M1/m1_3', loss_m1_3)
    loss_m2_sum = tf.summary.scalar('losses/M2', loss_m2)
    loss_n_sum = tf.summary.scalar('losses/N', loss_n)
    loss_sum = tf.summary.scalar('losses/total_loss', loss)

    part1_sum = tf.summary.merge(
        [loss_m1_sum, loss_m1_1_sum, loss_m1_2_sum, loss_m1_3_sum, loss_m2_sum])
    part2_sum = tf.summary.merge(
        [loss_n_sum, loss_sum, ])
    overall_sum = tf.summary.merge(
        [loss_sum, real_img_sum, fake_recon_img_sum, fake_img_sum, fake_img_traversal_sum, fake_2d_img_traversal_sum])

    # ============= session =============
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    writer = tf.summary.FileWriter(log_dir, sess.graph)

    # ============= Checkpoints =============
    if continue_train:
        print(" [*] before training, Load checkpoint ")
        print(" [*] Reading checkpoint...")

        ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
            print(ckpt_dir_continue, ckpt_name)
            print("Successful checkpoint upload")
        else:
            print("Failed checkpoint load")
    else:
        print(" [!] before training, no need to Load ")

    # ============= load pre-trained classifier checkpoint =============
    class_vars = [var for var in slim.get_variables_to_restore() if 'classifier' in var.name]
    name_to_var_map_local = {var.op.name: var for var in class_vars}
    temp_saver = tf.train.Saver(var_list=name_to_var_map_local)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls)
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name))
    print("Classifier checkpoint loaded.................")
    print(ckpt_dir_cls, ckpt_name)

    # ============= Training =============
    for e in range(1, EPOCHS + 1):
        np.random.shuffle(data)
        for i in range(data.shape[0] // BATCH_SIZE):
            if args.debug:
                image_paths = np.array([str(ind) for ind in my_data_loader.tmp_list])
            else:
                image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            img, labels = my_data_loader.load_images_and_labels(image_paths, image_dir=config['image_dir'], n_class=1,
                                                                file_names_dict=file_names_dict,
                                                                num_channel=channels, do_center_crop=True)

            labels = labels.ravel()
            labels = np.eye(NUMS_CLASS_cls)[labels.astype(int)]

            target_labels_probs = np.random.randint(0, high=NUMS_CLASS, size=BATCH_SIZE)
            target_labels_w_ind = np.random.randint(0, high=w_dim, size=BATCH_SIZE)
            target_labels = np.eye(w_dim)[target_labels_w_ind] * np.repeat(np.expand_dims(target_labels_probs, axis=-1), w_dim, axis=1)

            my_feed_dict = {y_target: target_labels, x_source: img, train_phase: True, y_s: labels}

            _, par1_loss, par1_summary_str, overall_sum_str, counter = sess.run([optimizer_1, loss1, part1_sum, overall_sum, global_step],
                                                                       feed_dict=my_feed_dict)

            writer.add_summary(par1_summary_str, global_step=counter)
            writer.add_summary(overall_sum_str, global_step=counter)

            _, part2_loss, part2_summary_str, overall_sum_str2, counter = sess.run([optimizer_2, loss2, part2_sum, overall_sum, global_step],
                                                                          feed_dict=my_feed_dict)
            writer.add_summary(part2_summary_str, global_step=counter)
            writer.add_summary(overall_sum_str2, global_step=counter)

            def save_results(sess, step):
                num_seed_imgs = BATCH_SIZE
                img, labels = my_data_loader.load_images_and_labels(image_paths[0:num_seed_imgs],
                                                                    image_dir=config['image_dir'], n_class=1,
                                                                    file_names_dict=file_names_dict,
                                                                    num_channel=channels,
                                                                    do_center_crop=True)

                labels = labels.ravel()
                labels = np.eye(NUMS_CLASS_cls)[labels.astype(int)]

                target_labels_probs = np.random.randint(0, high=NUMS_CLASS, size=BATCH_SIZE)
                target_labels_w_ind = np.random.randint(0, high=w_dim, size=BATCH_SIZE)
                target_labels = np.eye(w_dim)[target_labels_w_ind] * np.repeat(
                    np.expand_dims(target_labels_probs, axis=-1), w_dim, axis=1)

                my_feed_dict = {y_target: target_labels, x_source: img, train_phase: False,
                                y_s: labels}

                sample_fake_img_traversal, sample_fake_2d_img_traversal = sess.run([fake_img_traversal_save, fake_2d_img_traversal_save], feed_dict=my_feed_dict)

                # save samples
                sample_file = os.path.join(sample_dir, '%06d.jpg' % step)
                save_image(sample_fake_img_traversal, sample_file)

                sample_file = os.path.join(sample_dir, '%06d_2d.jpg' % step)
                save_image(sample_fake_2d_img_traversal, sample_file)

            batch_counter = int(counter/2)
            if batch_counter % save_summary == 0:
                save_results(sess, batch_counter)

            if batch_counter % save_ckpt == 0:
                saver.save(sess, ckpt_dir + "/model%2d.ckpt" % batch_counter, global_step=global_step)