Code Example #1
def main(data_folder_path, segment_folder_path):
    # Resolve the paths and verify that the directories exist
    data_folder = check_directory(data_folder_path)
    segment_folder = check_directory(segment_folder_path)

    # Collect the files with glob, sorted so header and log files stay paired
    file_formats = (
        ("headers", ".vhdr"),
        # ("markers", ".vrmk"),
        # ("data", ".dat"),
        ("logs", ".txt"))
    file_paths = {}

    for name, ending in file_formats:
        file_paths[name] = sorted(
            glob.glob(os.path.join(data_folder, "*" + ending)))

    # Create progress bar to make script more friendly
    bar = Bar('Segmenting data', max=len(file_paths["headers"]) * 2)

    # Iterate until no more files are left
    while len(file_paths["headers"]) > 0:
        bar.next()
        current_files = synched_pop(file_paths)
        wrapper = BrainvisionWrapper(current_files["headers"],
                                     current_files["logs"])
        wrapper.segment_data()
        wrapper.save_segmented_data(segment_folder)
        bar.next()
    bar.finish()
Code Example #2
def process_file(fname, f_list, p_list, directory):
    modified_data = []
    fieldnames = list(f_list) + list(p_list)
    fieldnames.insert(0, 'UNITID')

    with io.open(fname + '.csv', 'r', encoding='utf-8-sig') as csvfile:
        # The file is already decoded as UTF-8, so hand the text stream to
        # DictReader directly (re-encoding lines to bytes breaks csv in Python 3)
        reader = csv.DictReader(csvfile)
        for row in reader:
            modified_row = {}
            for field in fieldnames:
                modified_row[field] = row[field]
            modified_data.append(modified_row)

    # The processed files are written to the output directory.
    output_dir = directory + 'Processed/'
    utils.check_directory(output_dir)

    write_file = output_dir + fname + '_pruned.csv'
    with open(write_file, 'w') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for row in modified_data:
            writer.writerow(row)

    return output_dir
Code Example #3
    def save_segmented_data(self, segments_folder_path: str):
        """Saves segmented data to json files for further processing.

        Parameters
        ----------
        segments_folder_path: str
            path to the folder where the segments will be stored
        """

        assert self.dat.data_segments != [], ("Data must be segmented" +
                                              " before you save it.")

        check_directory(segments_folder_path)

        # Create a bar to display progress of saving to files
        bar = IncrementalBar("Saving segments",
                             max=len(self.dat.data_segments))

        segment_number = 0

        for sgmnt in self.dat.data_segments:

            file_name = "sujeto_{}_segmento_{}.json".format(
                str(self.log.trials[0]["subject"]), segment_number + 1)

            file_path = os.path.join(segments_folder_path, file_name)

            # Re-key the metadata, expressing audio and trigger positions
            # relative to the segment start
            metadata = sgmnt["metadata"]

            json_data = {
                "face_code": metadata["face_code"],
                "audio_code": metadata["audio_code"],
                "audio_file_name":
                self.log.trials[segment_number]["soundfile"],
                "audio_start_position":
                metadata["audio_start"] - metadata["start"],
                "audio_end_position":
                metadata["audio_end"] - metadata["start"],
                "trigger_position": metadata["stim_pos"] - metadata["start"],
                "channels": {}
            }

            for key in sgmnt.keys():
                if key == "metadata":
                    continue
                json_data["channels"][key] = sgmnt[key].tolist()

            with open(file_path, "w") as json_file:
                json.dump(json_data, json_file)

            segment_number += 1

            bar.next()
        bar.finish()
Code Example #4
def main():
    np.random.seed(0)

    check_directory(RESULT_DIR)
    check_directory(TFLOG_DIR)

    print('Build model.....', end='')
    model = build_model()
    print('Done!')

    print('Load file list in dataset.....', end='')
    train_gen, valid_gen = load_data()
    print('Done!')

    fpath = os.path.join(
        RESULT_DIR,
        'model.{epoch:04d}-{loss:.5f}-{acc:.3f}-{val_loss:.5f}-{val_acc:.3f}.h5'
    )
    checkpoint = ModelCheckpoint(fpath,
                                 monitor='val_loss',
                                 verbose=0,
                                 save_best_only=True,
                                 save_weights_only=False,
                                 mode='auto',
                                 period=1)
    lr_scheduler = LearningRateScheduler(schedule)
    early_stopping = EarlyStopping(monitor='val_loss', patience=30, verbose=1)
    tensor_board = TensorBoard(log_dir=TFLOG_DIR,
                               histogram_freq=0,
                               batch_size=BATCH_SIZE,
                               write_graph=False,
                               write_grads=True,
                               write_images=True)

    history = model.fit_generator(
        train_gen,
        steps_per_epoch=len(train_gen),
        epochs=EPOCH_COUNT,
        validation_data=valid_gen,
        validation_steps=len(valid_gen),
        callbacks=[checkpoint, lr_scheduler, early_stopping, tensor_board],
        workers=20,
        max_queue_size=500,
        use_multiprocessing=True,
        verbose=1)

    print('Save training result.....', end='')
    with open(os.path.join(RESULT_DIR, 'history.pickle'), 'wb') as f:
        pickle.dump(history.history, f)
    model.save(os.path.join(RESULT_DIR, 'model.h5'))
    print('Done!')
Code Example #5
def remove_rename(args):
    print('Preprocess start')
    start = time.time()  # record the start time
    # save_path check
    utils.check_directory(args.save_path)

    # read files & sorting
    annotation_files = os.listdir(args.annotation_path)
    images_files = os.listdir(args.image_path)
    annotation_files_sort = sorted(annotation_files)
    images_files_sort = sorted(images_files)
    assert len(annotation_files_sort) == len(images_files_sort), \
        'File counts do not match anno : {0}, images : {1}'.format(
            len(annotation_files), len(images_files))

    orgin_ano = []
    orgin_img = []
    re_ano = []
    re_img = []
    for i in annotation_files_sort:
        annotation_file = utils.tag_remove_parser(i)
        orgin_ano.append(i)
        re_ano.append(annotation_file)

    for i in images_files_sort:
        images_file = utils.tag_remove_parser(i)
        orgin_img.append(i)
        re_img.append(images_file)

    orgin_ano = np.array(orgin_ano)
    orgin_img = np.array(orgin_img)
    re_ano = np.array(re_ano)
    re_img = np.array(re_img)

    # Remove file
    utils.remove_files(re_ano, re_img, orgin_img, args.image_path)
    utils.remove_files(re_img, re_ano, orgin_ano, args.annotation_path)

    # Modify file name
    for i in annotation_files_sort:
        annotation_file = utils.tag_remove_parser(i)
        os.rename(os.path.join(args.annotation_path, i),
                  os.path.join(args.annotation_path, annotation_file + '.xml'))

    for i in images_files_sort:
        images_file = utils.tag_remove_parser(i)
        os.rename(os.path.join(args.image_path, i),
                  os.path.join(args.image_path, images_file + '.jpg'))

    print("Preprocessing time :", time.time() - start)
Code Example #6
File: analyze.py Project: tianyi21/scdr
def run_dr(data, dr_type, epoch=None, args=None, dim=2, cache=True):
    """
    Running TSNE/UMAP/PCA DR
    :param data: Original data
    :param dr_type: DR type TSNE/UMAP/PCA
    :param epoch: (Optional) -> DR for NN embedding
    :param args: (Optional) -> filtered data
    :param dim: DR dimension
    :param cache: Writing DR cache
    :return: embedding
    """
    def _dr(_data, _dr_type, _dim):
        print(">>> Running {} embedding".format(_dr_type.upper()))
        if np.array(_data).shape[1] == _dim:
            return _data
        tic = time.time()
        if _dr_type.upper() == "TSNE":
            _embedding = TSNE(n_components=2, n_jobs=TSNE_JOBS).fit_transform(_data)
        elif _dr_type.upper() == "UMAP":
            _embedding = umap.UMAP(n_neighbors=5, n_components=_dim).fit_transform(_data)
        elif _dr_type.upper() == "PCA":
            _embedding = PCA(n_components=_dim).fit_transform(data)
        else:
            raise Exception("!!! Invalid dr_type provided.")
        toc = time.time()
        print("    {} took {:.2f} s".format(_dr_type.upper(), toc - tic))
        return _embedding

    if epoch is None:
        file_name = dr_type.lower() + "_data_cache_"
        # Check cache
        if args is None and os.path.isfile(os.path.join(CACHE_DIR, file_name + "{}d.pkl".format(dim))):
            embedding = cache_operation(
                os.path.join(CACHE_DIR, file_name + "{}d.pkl".format(dim)), "read", text=" " + dr_type.upper())
        elif args is not None and os.path.isfile(os.path.join(CACHE_DIR, file_name + "{}m_{}s_{}d.pkl".format(
                args.mean_filter, args.sd_filter, dim))):
            embedding = cache_operation(os.path.join(CACHE_DIR, file_name + "{}m_{}s_{}d.pkl".format(
                    args.mean_filter, args.sd_filter, dim)), "read", text=" " + dr_type.upper())
        else:
            embedding = _dr(data, dr_type, dim)
            if cache:
                check_directory(CACHE_DIR)
                write_name = file_name + "{}d.pkl".format(
                    dim) if args is None else file_name + "{}m_{}s_{}d.pkl".format(
                    args.mean_filter, args.sd_filter, dim)
                cache_operation(os.path.join(CACHE_DIR, write_name), "write", embedding, text=" " + dr_type.upper())
    else:
        embedding = _dr(data, dr_type, dim)
    return embedding
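
The docstring above leaves the call pattern implicit. A minimal usage sketch, assuming data is an (n_samples, n_features) NumPy array and that the module-level CACHE_DIR and TSNE_JOBS constants are already configured; the array below is a hypothetical placeholder, not data from the project:

import numpy as np

# Hypothetical 100-sample, 50-feature matrix; the first call computes the 2-D
# UMAP embedding and writes it to CACHE_DIR, later calls read it back.
data = np.random.rand(100, 50)
embedding = run_dr(data, dr_type="umap", dim=2, cache=True)
print(embedding.shape)  # expected: (100, 2)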
Code Example #7
def check_db_config(config, allow_missing=False):

    if is_missing(config, 'host'):
        config['host'] = '127.0.0.1'
    if is_missing(config, 'dbname'):
        utils.print_error("Missing DB name")
        return False
    if is_missing(config, 'user'):
        if allow_missing:
            config['user'] = None
        else:
            utils.print_error("Missing DB user")
            return False
    if is_missing(config, 'password'):
        if allow_missing:
            config['password'] = None
        else:
            utils.print_error("Missing DB password")
            return False
    if is_missing(config, 'pointer'):
        utils.print_error("Missing DB pointer")
        return False

    config['sql'] = None
    if not is_missing(config, 'sql_output'):
        config['sql'] = utils.file_dir_name(
            config['sql_output']) + utils.file_file_name(config['sql_output'])
        config['dir'] = utils.file_dir_name(config['sql'])
        if not utils.check_directory(config['dir']):
            utils.print_error("Couldn't create directory for output " +
                              config["name"])
            return False

    return True
Code Example #8
def generate_college_csv(f_list, p_list, directory):
    fieldnames = list(f_list) + list(p_list)
    fieldnames.insert(0, 'YEAR')
    fieldnames.insert(0, 'UNITID')

    complete_data = defaultdict(dict)

    # In order to impute missing values for a particular college across years, we need to map
    # filenames to corresponding categorical values.
    fnameyear = utils.fnameyear
    # Changing the directory to the path containing the pruned_csv_files.
    os.chdir(directory)
    # Read the data for all 19 years and store it per college in the 'complete_data' dictionary.
    for file_name in glob.glob('*.csv'):
        with io.open(file_name, 'r', encoding='utf-8-sig') as csvfile:
            # The file is already decoded; pass the text stream to DictReader directly
            reader = csv.DictReader(csvfile)
            key = file_name.split('.')
            k = key[0]
            # Mapping the filename to the corresponding categorical value.
            year = fnameyear[k]
            for row in reader:
                row['YEAR'] = year
                temp = complete_data[row['UNITID']]
                temp[year] = row
                complete_data[row['UNITID']] = temp

    # Writing back the college information into file. These files will be used for further processing.
    output_dir = directory + 'CollegeData/'
    utils.check_directory(output_dir)
    for key, value in complete_data.items():
        write_file = output_dir + key + '.csv'
        with open(write_file, 'w') as newfile:
            writer = csv.DictWriter(newfile, fieldnames=fieldnames)
            writer.writeheader()
            for year, vals in value.items():
                writer.writerow(vals)

    return output_dir
Code Example #9
def main(from_dir_str, to_dir_str, factor_str):

    from_dir = check_directory(from_dir_str)
    to_dir = check_directory(to_dir_str)
    factor = int(factor_str)

    file_paths = glob.glob(os.path.join(from_dir, "*.wav"))

    bar = Bar('Downsampling audio files', max=len(file_paths))

    rate_warning = False  # Tracks whether the user already accepted the rounded rate

    for path in file_paths:
        filename = os.path.split(path)[1]

        rate, data = wavfile.read(path)

        if rate % factor != 0:
            if not rate_warning:
                print("WARNING: Rate should be multiple of factor")
                print(f"Rate: {rate}, new rate would be {rate/factor}.")
                a = input(f"Do you want to use {rate//factor} as rate? [y/n] ")
                if a == "y":
                    rate_warning = True
                    print(f"Using {rate//factor} as new rate.")
                else:
                    print("Aborting...")
                    sys.exit(1)

        data = data[::factor]
        rate = rate // factor

        wavfile.write(os.path.join(to_dir, filename), rate, data)
        bar.next()

    bar.finish()
    print("Audio files downsampled.")
Code Example #10
def check_file_config(config):

    if is_missing(config, 'dir'):
        config['dir'] = cfg.ini_dir
    config['dir'] = utils.full_dir_name(config['dir'])

    if not utils.check_directory(config['dir']):
        utils.print_error("Couldn't create directory for output " +
                          config["name"])
        return False

    if is_missing(config, 'ptr'):
        config['ptr'] = config['dir'] + config["name"] + ".ptr"
    if is_missing(config, 'out'):
        config['out'] = config['dir'] + config["name"] + ".out"

    return True
Code Example #11
def preprocess_images(old_dataset_folder, new_dataset_folder):
    execution_dir = getcwd()
    detector = caffe_get_detector(
        join(execution_dir, 'static', 'models', 'MobileNetSSD',
             'MobileNetSSD_deploy.prototxt'),
        join(execution_dir, 'static', 'models', 'MobileNetSSD',
             'MobileNetSSD_deploy.caffemodel'))
    old_color_folder = join(old_dataset_folder, 'color')
    old_depth_folder = join(old_dataset_folder, 'depth')
    new_color_folder = join(new_dataset_folder, 'color')
    new_depth_folder = join(new_dataset_folder, 'depth')
    check_directory(new_dataset_folder)
    check_directory(new_color_folder)
    check_directory(new_depth_folder)
    # Sort both listings so the color/depth frames pair up deterministically
    color_filenames = sorted(listdir(old_color_folder))
    depth_filenames = sorted(listdir(old_depth_folder))
    threads = []
    for color_filename, depth_filename in zip(color_filenames,
                                              depth_filenames):
        old_color_filepath = join(old_color_folder, color_filename)
        old_depth_filepath = join(old_depth_folder, depth_filename)
        new_color_filepath = join(new_color_folder, color_filename)
        new_depth_filepath = join(new_depth_folder, depth_filename)
        col_min, col_max, row_min, row_max = caffe_detect_body(
            detector=detector, image_path=old_color_filepath)
        print(f"Detect body: {old_color_filepath}")
        info_dict = {
            'col_min': col_min,
            'col_max': col_max,
            'row_min': row_min,
            'row_max': row_max,
            'old_color_filepath': old_color_filepath,
            'new_color_filepath': new_color_filepath,
            'old_depth_filepath': old_depth_filepath,
            'new_depth_filepath': new_depth_filepath,
        }
        thread = Thread(target=_tailor_image_pair, args=(info_dict, ))
        thread.start()
        threads.append(thread)
    # Join every worker thread, not just the last one that was started
    for thread in threads:
        thread.join()
Code Example #12
def main():
    # preprocessing
    # Delete unmatched files and unify the file-name structure
    preprocessing.remove_rename(args)

    # save_path check
    utils.check_directory(args.save_path)

    # read files & sorting
    annotation_files = os.listdir(args.annotation_path)
    images_files = os.listdir(args.image_path)

    annotation_files_sort = sorted(annotation_files)
    images_files_sort = sorted(images_files)
    assert len(annotation_files_sort) == len(images_files_sort), \
        'File counts do not match anno : {0}, images : {1}'.format(
            len(annotation_files), len(images_files))

    # start
    print('Crop start')
    start = time.time()  # record the start time
    crop_image_count = 0
    for i in range(len(images_files)):
        annotation_file = utils.tag_remove_parser(annotation_files_sort[i])
        images_file = utils.tag_remove_parser(images_files_sort[i])

        # .DS_Store: macOS metadata file that shows up in directory listings.
        if (annotation_file != images_file or annotation_file == '.DSStore'
                or images_file == '.DSStore'):
            print('File names do not match for file number {0}'.format(i))
            continue
        # read xml, image files
        tree = parse(
            os.path.join(
                args.annotation_path,
                annotation_file + '.xml',
            ))
        origin_image = Image.open(
            os.path.join(args.image_path, images_file + '.jpg'))
        # read xml
        root = tree.getroot()
        # Find first tag
        elements = root.findall("object")
        # Get Class name
        names = [x.findtext("name") for x in elements]
        # Get annotation
        xmin_list = []
        ymin_list = []
        xmax_list = []
        ymax_list = []
        for element in elements:
            # xml -> object -> bndbox -> [xmin, ymin, xmax, ymax]
            xmin_list.append(int(element.find('bndbox').find('xmin').text))
            xmax_list.append(int(element.find('bndbox').find('xmax').text))
            ymin_list.append(int(element.find('bndbox').find('ymin').text))
            ymax_list.append(int(element.find('bndbox').find('ymax').text))
        # image crop & save
        for idx, name in enumerate(names):
            bndbox_area = (xmin_list[idx], ymin_list[idx], xmax_list[idx],
                           ymax_list[idx])
            crop_image = origin_image.crop(bndbox_area)
            crop_image.save(
                os.path.join(args.save_path,
                             '{0}_{1}_{2}.jpg'.format(images_file, name, idx)))
            # image generate counting
            crop_image_count += 1
    print('Crop end')
    print('Number of generated images :', crop_image_count)
    print("Crop time :", time.time() - start)

    print('File move start')
    start = time.time()  # record the start time
    utils.move_files(args.save_path)
    print("Move time :", time.time() - start)
    print('File move end')
Code Example #13
#This script simulates wedged MLLs following Andrejczuk's paper
#"Influence of imperfections in a wedged MLL ..."
#This is the main script to execute
import config as cf
import matplotlib.pyplot as plt
import build_setup as bs
import utils as ut
################################################
#SINGLE SCAN
################################################
if cf.scanmode in ("single", "Single", "s"):
    #---------------------------------------------------------------------
    ###################################
    #Preparations
    ###################################
    ut.check_directory(cf.save_directory)
    ut.check_lenstype()
    print("Making the samples...")
    #Making the samples
    if cf.mk_slit:
        print("Making slit...")
        slits_pre_result = bs.b_slit()  #make the slit sample
        slits = slits_pre_result[0]
        stepslit = slits_pre_result[1]
    #Making vacuum
    print("Making vacuum...")
    opt_const_vac = bs.b_vac()
    #Now making the mll if flat mll
    if cf.mll_type == "flat":
        print("Making flat mll...")
        mll_pre_result = bs.b_mll()
Code Example #14
def main():

    parser = argparse.ArgumentParser()
    # file path argument
    parser.add_argument("--img_dir",
                        type=str,
                        default='data/images/IMX219-83',
                        help="directory of images")
    parser.add_argument("--take_img",
                        type=str,
                        default='False',
                        help="Take depth images")
    parser.add_argument("--img_ext",
                        type=str,
                        default='.jpg',
                        help="Images extention")
    parser.add_argument("--show_plt",
                        type=str,
                        default='False',
                        help="Show plots")

    # create argument object
    args = parser.parse_args()

    camera_data = utils.read_json('config/jetson/config.json')

    left_pipeline = camera_data['waveshare_camera']['left']
    right_pipeline = camera_data['waveshare_camera']['right']
    api = camera_data['waveshare_camera']['apiEnum']

    utils.check_directory(args.img_dir)
    if args.take_img == 'True':
        captureImage(left_pipeline,
                     api=api,
                     save_dir=args.img_dir,
                     img_name='img_l',
                     show_img=False)
        captureImage(right_pipeline,
                     api=api,
                     save_dir=args.img_dir,
                     img_name='img_r',
                     show_img=False)

    # Read the stereo-pair of images
    images = utils.files_in_dir(args.img_dir, args.img_ext)
    assert len(images) == 2
    # cv2.imread loads BGR; swap to RGB so matplotlib displays correct colors
    img_left = cv2.cvtColor(cv2.imread(images[0]), cv2.COLOR_BGR2RGB)
    img_right = cv2.cvtColor(cv2.imread(images[1]), cv2.COLOR_BGR2RGB)

    if args.show_plt == 'True':
        # Large plot of the left image
        plt.figure(figsize=(10, 10), dpi=100)
        plt.imshow(img_left)
        plt.show()

    disp_left = stereo.compute_left_disparity_map(img_left, img_right)

    if args.show_plt == 'True':
        # Show the left disparity map
        plt.figure(figsize=(10, 10))
        plt.imshow(disp_left)
        plt.show()
Code Example #15
File: analyze.py Project: tianyi21/scdr
def plot_embedding(embedding, assignment=None, label=None, batch_correction=None,
                   title=None, epoch=None, show=False, dr_type="TSNE", anno=False):
    """
    Visualize embedding
    :param embedding: embedding
    :param assignment: (Optional) clustering assignment
    :param label: (Optional) ground truth assignment
    :param title: (Optional) plot title
    :param epoch: (Optional) visualize epoch
    :param show: show plots
    :param dr_type: embedding type
    :return: none
    """

    def _2d_plot(embedding, data, title, anno=False):
        unique_data = np.unique(data)
        cmap = plt.get_cmap("Spectral")
        colors = cmap(np.linspace(0, 1.0, len(unique_data)))
        for item, color in zip(unique_data, colors):
            plt.scatter(embedding[data == item, 0], embedding[data == item, 1], s=1, label=item, c=[color])
        if anno:
            for item, txt in enumerate(data):
                if item % 50 == 0:
                    plt.annotate(txt, (embedding[item, 0], embedding[item, 1]))
        plt.title(title)
        plt.legend(loc="upper right")

    embedding = np.array(embedding)
    if embedding.shape[1] != 2:
        print("!!! 2D embedding expected for visualization while {}D provided".format(embedding.shape[1]))

    check_directory(VISUL_DIR)
    ext = "" if epoch is None else " Epoch " + str(epoch + 1)

    # simple plot
    if assignment is None and label is None:
        plt.scatter(embedding[:, 0], embedding[:, 1], s=1)
        plt.title(dr_type + " Embedding of {}".format(title))
        plt.legend(loc="upper right")
        plt.savefig(os.path.join(VISUL_DIR, dr_type.lower() + "_{}.pdf".format(title)), dpi=400)
    
    # plot clustering results without label
    elif assignment is not None and label is None:
        _2d_plot(embedding, assignment, "Cluster Assignment{}".format(ext), anno)
        plt.savefig(os.path.join(VISUL_DIR, dr_type.lower() + "_cls{}.pdf".format(ext.replace(" ", "_").lower())), dpi=400)

    # plot labels
    elif assignment is None and label is not None and batch_correction is None:
        _2d_plot(embedding, label, "Labels", anno)
        plt.savefig(os.path.join(VISUL_DIR, dr_type.lower() + "_label{}.pdf".format(ext.replace(" ", "_").lower())), dpi=400)
    
    # plot labels with batch correction
    elif assignment is None and label is not None and batch_correction is not None:
        plt.figure(figsize=(15,8))
        plt.subplot(121)
        _2d_plot(embedding, label, "Labels", anno)
        plt.subplot(122)
        _2d_plot(embedding, batch_correction, "Batch Correction", False)
        plt.tight_layout()
        plt.savefig(os.path.join(VISUL_DIR, dr_type.lower() + "_bc_label{}.pdf".format(ext.replace(" ", "_").lower())), dpi=400)
    
    # plot clustering results with labels
    elif assignment is not None and label is not None and batch_correction is None:
        plt.figure(figsize=(15, 8))
        plt.subplot(121)
        _2d_plot(embedding, assignment, "Cluster Assignment{}".format(ext), anno)
        plt.subplot(122)
        _2d_plot(embedding, label, "Labels", anno)
        plt.tight_layout()
        plt.savefig(os.path.join(VISUL_DIR, dr_type.lower() + "_cls_label{}.pdf".format(ext.replace(" ", "_").lower())), dpi=400)
    
    # plot clustering results with labels and batch correction
    elif assignment is not None and label is not None and batch_correction is not None:
        plt.figure(figsize=(21, 8))
        plt.subplot(131)
        _2d_plot(embedding, assignment, "Cluster Assignment{}".format(ext), anno)
        plt.subplot(132)
        _2d_plot(embedding, batch_correction, "Batch Correction{}".format(ext), False)
        plt.subplot(133)
        _2d_plot(embedding, label, "Labels", anno)
        plt.tight_layout()
        plt.savefig(os.path.join(VISUL_DIR, dr_type.lower() + "_cls_bc_label{}.pdf".format(ext.replace(" ", "_").lower())), dpi=400)

    if show:
        plt.show()
    plt.clf()
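
A usage sketch for the plotting paths above, assuming VISUL_DIR is configured and a 2-D embedding is available; the arrays and names below are hypothetical placeholders, not data from the project:

import numpy as np

# Fake 2-D embedding with per-point cluster assignments and ground-truth labels
embedding = np.random.rand(300, 2)
cluster_ids = np.random.randint(0, 4, size=300)
true_labels = np.random.choice(["B", "T", "NK"], size=300)

# Plain embedding plot (no clustering, no labels)
plot_embedding(embedding, title="demo", dr_type="UMAP")

# Cluster assignment next to ground-truth labels (two-panel figure)
plot_embedding(embedding, assignment=cluster_ids, label=true_labels,
               epoch=9, dr_type="UMAP")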
Code Example #16
synched_pop: calls the pop method on all the elements of a dictionary
    at the same position.
main: The main function with the functionality.
"""

import os
import sys
import glob

from brainvision_wrangling import BrainvisionWrapper
from utils import check_directory
from progress.bar import Bar

from segment_data import synched_pop

data_folder = check_directory(sys.argv[1])
segment_folder = check_directory(sys.argv[2])


def facecode(filename: str) -> int:
    """ Translates the string facestim id to the numerical ones """
    faceid = filename[1]
    if filename[2] == "n":
        scr = "1"
    elif filename[2] == "s":
        scr = "2"
    else:
        scr = "0"
    facecode = scr + faceid
    return int(facecode)
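
The module docstring at the top of this snippet describes synched_pop, which is imported from segment_data, but the snippet does not include its definition. A minimal sketch of such a helper, assuming every value in the dictionary is a list and the item at the same index (the first, by default) is popped from each; this is an illustration, not the project's implementation:

def synched_pop(paths_dict: dict, index: int = 0) -> dict:
    """Pop the item at `index` from every list in the dictionary and
    return the popped items under the same keys."""
    return {name: paths.pop(index) for name, paths in paths_dict.items()}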
parser.set_defaults(test=False)
parser.add_argument('--limit',
                    type=int,
                    default=0,
                    help='limit the size of whole data set')
parser.add_argument('--restore',
                    dest='restore',
                    action='store_true',
                    help='Reload the saved model')
parser.set_defaults(restore=False)
args = parser.parse_args()

torch.manual_seed(args.seed)
random.seed(args.seed)

check_directory(args.save)
# Read data
my_lang, document_list = utils.build_lang(args.data)
random.shuffle(document_list)
if args.limit != 0:
    document_list = document_list[:args.limit]
cut = int(len(document_list) * args.validation_p)
training_data, validation_data = \
        document_list[cut:], document_list[:cut]
# Test mode
if args.test:
    # Load last model
    number = torch.load(os.path.join(args.save, 'checkpoint.pt'))
    encoder = torch.load(
        os.path.join(args.save, 'encoder' + str(number) + '.pt'))
    context = torch.load(