Example #1
def save_outputs(net, output_path, test_ids, patch_sizes, strides):
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    def model(img):
        outputs = net(img)
        pred = F.softmax(outputs, dim=1)
        return pred

    for test_id in test_ids:
        img_path = os.path.join(INPUT_DIR, IMAGES_DIR, test_id+'.tif')
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        for patch_size, stride in zip(patch_sizes, strides):
            if patch_size is None:
                padded_img, pads = helper.pad(img)
                padded_img = transforms.ToTensor()(padded_img)
                padded_img = transforms.Normalize(IMAGES_MEAN, IMAGES_STD)(padded_img)
                padded_img = torch.unsqueeze(padded_img, 0)
                padded_img = padded_img.float().cuda()  # torch.tensor() on an existing tensor is deprecated
                pred = np.squeeze(model(padded_img).data.cpu().numpy())
                pred = np.moveaxis(pred, 0, -1)
                pred = helper.unpad(pred, pads)
                pred = np.moveaxis(pred, -1, 0)
                pred_path = 'full_size'
            else:
                pred = predict(model, img, patch_size, patch_size, stride, stride, normalize_img=True)
                pred_path = str(patch_size) + '_' + str(stride)

            pred_path = os.path.join(output_path, pred_path)
            if not os.path.exists(pred_path):
                os.makedirs(pred_path)
            pred_path = os.path.join(pred_path, test_id)
            np.save(pred_path, pred)
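
Note: every example on this page delegates tiled inference to a project-level `predict` helper whose implementation is not shown. The sketch below is a minimal sliding-window predictor consistent with the calls above, assuming the project constants `IMAGES_MEAN` and `IMAGES_STD` and a model that returns a 1 x C x H x W score tensor; border handling and the `image_scales_number` option are omitted, so treat it as an illustration rather than the project's actual code.

import numpy as np
import torch
from torchvision import transforms

def predict(model, img, patch_h, patch_w, stride_h, stride_w, normalize_img=False):
    h, w = img.shape[:2]
    scores = None
    counts = np.zeros((h, w), dtype=np.float32)
    for y in range(0, max(h - patch_h, 0) + 1, stride_h):
        for x in range(0, max(w - patch_w, 0) + 1, stride_w):
            patch = transforms.ToTensor()(img[y:y + patch_h, x:x + patch_w])
            if normalize_img:
                patch = transforms.Normalize(IMAGES_MEAN, IMAGES_STD)(patch)
            with torch.no_grad():
                out = model(patch.unsqueeze(0).float().cuda())[0].cpu().numpy()
            if scores is None:
                scores = np.zeros((out.shape[0], h, w), dtype=np.float32)
            # accumulate scores and hit counts to average overlapping patches
            scores[:, y:y + patch_h, x:x + patch_w] += out
            counts[y:y + patch_h, x:x + patch_w] += 1
    return scores / np.maximum(counts, 1)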
Example #2
File: tester.py Project: AlvaroHYM/seq
    def eval_loader(self, loader):
        """
			Evaluate over a specific dataloader

		Args:
			loader: torch.DataLoader instance
		"""
        return predict(model=self.model,
                       pipeline=self.pipeline,
                       dataloader=loader,
                       task=self.task,
                       mode="eval")
Example #3
def do_prediction(net,
                  output_path,
                  test_ids,
                  patch_size,
                  stride,
                  post_processing,
                  dilation,
                  image_scales_number,
                  visualize=False):
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        os.makedirs(os.path.join(output_path, LABELS_DIR))

    def model(img):
        outputs = net(img)[0]
        # outputs = torch.cat(outputs, dim=0)
        # outputs = torch.mean(outputs, dim=0, keepdim=True)
        pred = F.softmax(outputs, dim=1)
        return pred

    sum_agg_jac = 0
    sum_dice = 0
    for test_id in test_ids:
        img_path = os.path.join(INPUT_DIR, IMAGES_DIR, test_id + '.tif')
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        pred = predict(model,
                       img,
                       patch_size,
                       patch_size,
                       stride,
                       stride,
                       normalize_img=True,
                       image_scales_number=image_scales_number)
        pred_labels = post_processing(pred, dilation=dilation)
        num_labels = np.max(pred_labels)
        colored_labels = \
            skimage.color.label2rgb(pred_labels, colors=helper.get_spaced_colors(num_labels)).astype(np.uint8)
        pred_labels_path = os.path.join(output_path, LABELS_DIR, test_id)
        pred_colored_labels_path = os.path.join(output_path, test_id + '.png')
        np.save(pred_labels_path, pred_labels)
        sio.savemat(pred_labels_path + '.mat', {'predicted_map': pred_labels},
                    do_compression=True)
        bgr_labels = cv2.cvtColor(colored_labels, cv2.COLOR_RGB2BGR)
        cv2.imwrite(pred_colored_labels_path, bgr_labels)

        if visualize:
            plt.imshow(img)
            plt.imshow(colored_labels, alpha=0.5)
            centroids = np.argmax(pred, axis=0) == 3
            plt.imshow(centroids, alpha=0.5)
            plt.show()
            cv2.waitKey(0)

        labels_path = os.path.join(INPUT_DIR, LABELS_DIR, test_id + '.npy')
        gt_labels = np.load(labels_path).astype(int)  # np.int was removed in NumPy 1.24
        agg_jac = aggregated_jaccard(pred_labels, gt_labels)
        sum_agg_jac += agg_jac
        print('{}\'s Aggregated Jaccard Index: {:.4f}'.format(
            test_id, agg_jac))

        mask_path = os.path.join(INPUT_DIR, MASKS_DIR, test_id + '.png')
        gt_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) / 255
        pred_mask = np.argmax(pred, axis=0) >= 2
        if dilation is not None:
            pred_mask = skmorph.dilation(pred_mask, skmorph.disk(dilation))
        dice = dice_index(pred_mask, gt_mask)
        sum_dice += dice
        print('{}\'s Dice Index: {:.4f}'.format(test_id, dice))

    print('--------------------------------------')
    print('Mean Aggregated Jaccard Index: {:.4f}'.format(sum_agg_jac /
                                                         len(test_ids)))
    print('Mean Dice Index: {:.4f}'.format(sum_dice / len(test_ids)))
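
For reference, the Dice index printed above is the standard overlap metric 2|A ∩ B| / (|A| + |B|) over binary masks; the Aggregated Jaccard Index is the instance-aware metric from the MoNuSeg nuclei-segmentation benchmark and is too involved to sketch here. A minimal `dice_index` consistent with these calls (the project's own implementation may handle edge cases differently):

import numpy as np

def dice_index(pred_mask, gt_mask):
    pred = np.asarray(pred_mask, dtype=bool)
    gt = np.asarray(gt_mask, dtype=bool)
    total = pred.sum() + gt.sum()
    if total == 0:
        return 1.0  # define empty vs. empty as a perfect match
    return 2.0 * np.logical_and(pred, gt).sum() / total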
Example #4
def do_prediction(net,
                  output_path,
                  test_ids,
                  patch_size,
                  stride,
                  post_processing,
                  labeling,
                  visualize=False,
                  in_channels=3):
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        os.makedirs(os.path.join(output_path, LABELS_DIR))

    def model(img):
        outputs = net(img)
        pred = F.log_softmax(outputs, dim=1)
        pred = (pred.argmax(dim=1, keepdim=True) == 1)
        return pred

    sum_agg_jac = 0
    sum_dice = 0
    for test_id in test_ids:
        img_path = os.path.join(INPUT_DIR, IMAGES_DIR, test_id + '.tif')
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if in_channels > 3:
            pred_mask_path = os.path.join(INPUT_DIR, PRED_MASKS_DIR,
                                          test_id + '.png')
            pred_mask = cv2.imread(pred_mask_path, cv2.IMREAD_GRAYSCALE)
            pred_mask = np.expand_dims(pred_mask, -1)

            img = np.concatenate((img, pred_mask), axis=-1)

        pred = predict(model, img, patch_size, patch_size, stride, stride)
        pred_labels = post_processing(pred)
        io.imsave(os.path.join(output_path, test_id + '.png'),
                  pred_labels * 255)
        if labeling:
            num_labels = np.max(pred_labels)
            colored_labels = \
                skimage.color.label2rgb(pred_labels, colors=helper.get_spaced_colors(num_labels)).astype(np.uint8)
            pred_labels_path = os.path.join(output_path, LABELS_DIR, test_id)
            pred_colored_labels_path = os.path.join(output_path,
                                                    test_id + '.png')
            np.save(pred_labels_path, pred_labels)
            bgr_labels = cv2.cvtColor(colored_labels, cv2.COLOR_RGB2BGR)
            cv2.imwrite(pred_colored_labels_path, bgr_labels)

        if visualize:
            plt.imshow(img)
            if labeling:
                plt.imshow(colored_labels, alpha=0.5)
            else:
                plt.imshow(pred_labels, alpha=0.5)
            plt.show()
            cv2.waitKey(0)

        if labeling:
            # colored_labels_path = os.path.join(INPUT_DIR, COLORED_LABELS_DIR, test_id+'.png')
            # labels_img = cv2.imread(colored_labels_path)
            # gt_labels = helper.rgb2label(labels_img)
            labels_path = os.path.join(INPUT_DIR, LABELS_DIR, test_id + '.npy')
            gt_labels = np.load(labels_path)
            agg_jac = aggregated_jaccard(pred_labels, gt_labels)
            sum_agg_jac += agg_jac
            print('{}\'s Aggregated Jaccard Index: {:.4f}'.format(
                test_id, agg_jac))

        mask_path = os.path.join(INPUT_DIR, INSIDE_MASKS_DIR, test_id + '.png')
        gt_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) / 255
        dice = dice_index(pred_labels, gt_mask)
        sum_dice += dice
        print('{}\'s Dice Index: {:.4f}'.format(test_id, dice))

    print('--------------------------------------')
    if labeling:
        print('Mean Aggregated Jaccard Index: {:.4f}'.format(sum_agg_jac /
                                                             len(test_ids)))
    print('Mean Dice Index: {:.4f}'.format(sum_dice / len(test_ids)))
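
Examples #3, #4 and #7 also accept a `post_processing` callable that turns raw scores into an instance label map (so `np.max(pred_labels)` is the instance count). Below is a hedged sketch of the simplest such function, using plain connected components; the project's real post-processing is likely more elaborate (e.g. watershed splitting of touching nuclei), and the function name here is hypothetical.

import numpy as np
import skimage.measure
import skimage.morphology as skmorph

def connected_component_post_processing(pred, dilation=None):
    # foreground = classes >= 2, following the mask logic in Example #3
    mask = np.argmax(pred, axis=0) >= 2
    if dilation is not None:
        mask = skmorph.dilation(mask, skmorph.disk(dilation))
    return skimage.measure.label(mask)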
Example #5
def app():
    st.title('Traffic Speed Prediction')
    st.write("> Let's try predicting some traffic speed!")
    current_dir = dirname(dirname(abspath(__file__)))

    # Sidebar --------------------------------------------------------------------------------
    st.sidebar.title("Model Options")
    st.sidebar.write("Select your prediction options below")
    st.sidebar.write(
        "> There are a few components to this model to indicate how much past data you want to input, and how far in the future you want to predict until."
    )

    # Input Timesteps
    st.sidebar.subheader("Select Input Timesteps")
    st.sidebar.write(
        "How much past data to input into the model for prediction")
    # input_timestep_options = {1: "1 (5 minutes)", 2: "2 (10 minutes)", 3: "3 (15 minutes) - optimal"}
    input_timestep_options = {8: "8 (40 minutes) - default"}
    num_input_timesteps = st.sidebar.selectbox(
        "Number of Input Timesteps",
        options=list(input_timestep_options.keys()),
        format_func=lambda x: input_timestep_options[x])

    # Output Timesteps
    st.sidebar.subheader("Select Output Timesteps")
    st.sidebar.write("How far do you want to predict the traffic speeds")
    output_timestep_options = {
        1: "1 (5 minutes)",
        2: "2 (10 minutes)",
        3: "3 (15 minutes)",
        4: "4 (20 minutes)"
    }
    num_output_timesteps = st.sidebar.selectbox(
        "Number of Output Timesteps",
        options=list(output_timestep_options.keys()),
        format_func=lambda x: output_timestep_options[x],
        index=3)

    # -----------------------------------------------------------------------------------

    # Main Content --------------------------------------------------------------------------------
    sample_zip_path = os.path.join(current_dir, 'data', 'sample', 'input.zip')

    st.write("### 1. Download Sample Input Files")
    st.write(
        "Here's a sample input file with the format that is required for the model prediction. You can download this, change the data and upload the zip file below."
    )
    download_local_button(sample_zip_path, 'input.zip', 'Download files')
    st.write("___________________________")
    st.write("### 2. Upload Input Files")
    st.write("Please upload the zip file with the correct format below")
    zip_file = st.file_uploader("Upload file", type="zip")
    if zip_file is not None:
        # file_details = {'file_name': zip_file.name, 'file_type': zip_file.type}

        # Saving File
        saved_zip_path = os.path.join(current_dir, 'data', zip_file.name)
        with open(saved_zip_path, 'wb') as f:
            f.write(zip_file.getbuffer())

        with ZipFile(saved_zip_path, 'r') as zf:  # avoid shadowing the builtin zip()
            # printing all the contents of the zip file
            zf.printdir()
            # extracting all the files
            print('Extracting all the files now...')
            unzip_path = os.path.join(current_dir, 'data', 'raw')
            zf.extractall(path=unzip_path)
            print('Done!')

        st.success('File Uploaded! You can now predict traffic speeds')

        # Predict Traffic Speeds here
        if st.button("Predict Traffic Speeds", key='predict'):
            with st.spinner("Please wait for prediction results...."):
                st.write('## Results')
                results, A, X, metadata = predict(num_timesteps_input=8,
                                                  num_timesteps_output=4)

                # Display Metadata
                st.write('#### Metadata')
                # st.beta_expander was renamed to st.expander in Streamlit >= 0.86
                metadata_expander = st.expander("Click to expand",
                                                expanded=False)
                with metadata_expander:
                    st.write(
                        "Here's the metadata of the input data you have uploaded"
                    )
                    df = pd.DataFrame(metadata).transpose()
                    st.write(df)
                    download_button(df, 'metadata.csv', 'Download metadata')

                # Display Results
                st.write('#### Predictions')
                predictions_expander = st.expander("Click to expand",
                                                   expanded=False)
                with predictions_expander:

                    def loc_to_linestring(loc):
                        coordArr = loc.split()
                        coordArr = [float(coord) for coord in coordArr]
                        return LineString([coordArr[1::-1], coordArr[3:1:-1]])

                    def plotGeoPerformance(metadata, speedbands):
                        df = pd.DataFrame(metadata).transpose()
                        df["speedbands"] = speedbands
                        loc = df["start_pos"] + " " + df["end_pos"]
                        linestrings = loc.apply(loc_to_linestring)
                        gdf = gpd.GeoDataFrame(df,
                                               geometry=linestrings,
                                               crs="EPSG:4326")
                        gdf = gdf.to_crs('EPSG:3857')
                        fig, ax = plt.subplots(figsize=(10, 10))
                        gdf.plot(ax=ax,
                                 column="speedbands",
                                 legend=True,
                                 cmap="OrRd",
                                 legend_kwds={'label': 'speedbands'})
                        ax.set_xlabel("Longitude")
                        ax.set_ylabel("Latitude")
                        ctx.add_basemap(ax)
                        plt.savefig("currentPrediction.png")

                    timestep_speedbands = results.reshape(
                        results.shape[2], results.shape[1])
                    plotGeoPerformance(
                        metadata,
                        timestep_speedbands[num_output_timesteps - 1])

                    st.write(
                        "Below is a graph of the predicted traffic speedbands plotted on the roads of this geographical map. The colours of the roads represent the varying speedband numbers."
                    )
                    st.image("currentPrediction.png")
                    st.write(
                        "Below is a table of the predicted traffic speedbands for the respective roads. Please refer to the metadata above for the index mappings"
                    )
                    results = results[:, :, num_output_timesteps - 1]
                    results = pd.DataFrame(results)
                    st.write(results)
                    download_button(results, 'predictions.csv',
                                    'Download predictions')
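
The `download_button` and `download_local_button` helpers above are app-specific and not shown. Below is a minimal sketch of equivalents built on Streamlit's native `st.download_button`; the original helpers, written against an older Streamlit, may instead have assembled a base64 download link.

import pandas as pd
import streamlit as st

def download_button(df, file_name, label):
    # serialize the DataFrame to CSV and expose it through the download widget
    st.download_button(label, data=df.to_csv(index=False),
                       file_name=file_name, mime='text/csv')

def download_local_button(path, file_name, label):
    with open(path, 'rb') as f:
        st.download_button(label, data=f.read(), file_name=file_name)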
Example #6
def main(args, logger):

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    all_tasks = [args.task_name] + args.cont_task_names

    # check for NaN values and end the experiment promptly
    torch.autograd.set_detect_anomaly(True)

    if args.do_train:

        loading_path = args.load_ckpt

        for i, task in enumerate(all_tasks):

            if task.lower() not in PROCESSORS:
                raise ValueError("Task not found: %s" % (task.lower()))
            logger.info('*** Start training for {} ***'.format(task))

            processor = PROCESSORS[task.lower()](args.num_ex)
            num_labels = NUM_LABELS_TASK[task.lower()]
            task_type = TASK_TYPE[task.lower()]
            label_list = None
            if task_type != 1:
                label_list = processor.get_labels()

            # make output_dir
            os.makedirs(os.path.join(args.output_dir, task), exist_ok=False)

            # init tensorboard writer
            tensorboard_writer = SummaryWriter(os.path.join(
                args.log_dir, task))

            train_examples = processor.get_train_examples(
                os.path.join(args.data_dir, task))
            num_train_steps = int(
                len(train_examples) / args.train_batch_size /
                args.gradient_accumulation_steps * args.num_train_epochs)

            if args.do_eval:
                # prepare eval data
                eval_dataloader = prepare_data_loader(args,
                                                      processor,
                                                      label_list,
                                                      task_type,
                                                      task,
                                                      tokenizer,
                                                      split='dev')[0]

            # Prepare model
            opt = {
                'bidirect': args.bidirect,
                'sub_word_masking': args.sub_word_masking,
                'nRoles': args.nRoles,
                'nSymbols': args.nSymbols,
                'dRoles': args.dRoles,
                'dSymbols': args.dSymbols,
                'encoder': args.encoder,
                'fixed_Role': args.fixed_Role,
                'scale_val': args.scale_val,
                'train_scale': args.train_scale,
                'aggregate': args.aggregate,
                'freeze_bert': args.freeze_bert,
                'num_rnn_layers': args.num_rnn_layers,
                'num_extra_layers': args.num_extra_layers,
                'num_heads': args.num_heads,
                'do_src_mask': args.do_src_mask,
                'ortho_reg': args.ortho_reg,
                'inductive_reg': args.inductive_reg,
                'cls': args.cls
            }
            logger.info('*' * 50)
            logger.info('option for training: {}'.format(args))
            logger.info('*' * 50)
            # also print it for philly debugging
            print('option for training: {}'.format(args))

            model, bert_config = prepare_model(args, opt, num_labels,
                                               task_type, device, n_gpu,
                                               loading_path)

            print(
                'num_elems:',
                sum([
                    p.nelement() for p in model.parameters() if p.requires_grad
                ]))

            # Prepare optimizer
            optimizer, scheduler, t_total = prepare_optim(
                args,
                num_train_steps,
                param_optimizer=list(model.named_parameters()))

            global_step = 0
            best_eval_accuracy = -float('inf')
            best_model = None
            last_update = 0

            train_dataloader = prepare_data_loader(args,
                                                   processor,
                                                   label_list,
                                                   task_type,
                                                   task,
                                                   tokenizer,
                                                   split='train')[0]

            for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
                model.train()
                tr_loss = 0
                nb_tr_examples, nb_tr_steps = 0, 0
                for step, batch in enumerate(
                        tqdm(train_dataloader, desc="Iteration")):
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_mask, segment_ids, sub_word_masks, orig_to_token_maps, label_ids = batch
                    _, loss, _ = model(input_ids, segment_ids, input_mask,
                                       sub_word_masks, label_ids)
                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()

                        if args.debug:
                            print('\n')
                            for name, value in model.named_parameters():
                                if value.requires_grad and value.grad is not None:
                                    print('{}: {}'.format(
                                        name, (torch.max(abs(value.grad)),
                                               torch.mean(abs(value.grad)),
                                               torch.min(abs(value.grad)))))

                    tr_loss += loss.item()
                    nb_tr_examples += input_ids.size(0)
                    nb_tr_steps += 1
                    if (step + 1) % args.gradient_accumulation_steps == 0:

                        # modify scaling factor
                        with torch.no_grad():
                            pre = model.module if hasattr(model,
                                                          'module') else model
                            if args.do_decay and hasattr(pre.head, 'scale'):
                                pre.head.scale.copy_(
                                    torch.tensor(decay(
                                        pre.head.scale.cpu().numpy(),
                                        args.mode, args.final_ratio,
                                        global_step, t_total),
                                                 dtype=pre.head.scale.dtype))

                        if args.debug:
                            # DataParallel prefixes parameter names with 'module.'
                            pre = 'module.' if hasattr(model, 'module') else ''

                            cls_w = dict(model.named_parameters())[
                                pre + 'classifier.weight'].clone()
                            cls_b = dict(model.named_parameters())[
                                pre + 'classifier.bias'].clone()

                        if args.optimizer == 'adam':
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(), args.max_grad_norm)
                            optimizer.step()
                            scheduler.step()
                        else:
                            optimizer.step()

                        if args.debug:
                            new_cls_w = dict(model.named_parameters())[
                                pre + 'classifier.weight'].clone()
                            new_cls_b = dict(model.named_parameters())[
                                pre + 'classifier.bias'].clone()

                            print('diff weight is: {}'.format(new_cls_w -
                                                              cls_w))
                            print('diff bias is: {}'.format(new_cls_b - cls_b))

                        optimizer.zero_grad()
                        global_step += 1

                    if (global_step % args.log_every
                            == 0) and (global_step != 0):
                        tensorboard_writer.add_scalar('train/loss', tr_loss,
                                                      global_step)

                # Save a trained model after each epoch
                if not args.save_best_only or not args.do_eval:
                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model itself
                    output_model_file = os.path.join(*[
                        args.output_dir, task, "pytorch_model_{}.bin".format(
                            epoch)
                    ])
                    loading_path = output_model_file
                    logger.info(
                        "Saving checkpoint pytorch_model_{}.bin to {}".format(
                            epoch, output_model_file))
                    torch.save(
                        {
                            'state_dict': model_to_save.state_dict(),
                            'options': opt,
                            'bert_config': bert_config
                        }, output_model_file)

                if args.do_eval:
                    # evaluate model after every epoch
                    model.eval()
                    result, _ = evaluate(args, model, eval_dataloader, device,
                                         task_type, global_step, tr_loss,
                                         nb_tr_steps)
                    for key in sorted(result.keys()):
                        if key == 'eval_loss':
                            tensorboard_writer.add_scalar(
                                'eval/loss', result[key], global_step)
                        elif key == 'eval_accuracy':
                            tensorboard_writer.add_scalar(
                                'eval/accuracy', result[key], global_step)
                        logger.info("  %s = %s", key, str(result[key]))

                    if result['eval_accuracy'] - best_eval_accuracy > args.tolerance:
                        last_update = epoch
                    if last_update + args.patience <= epoch:
                        logger.info(
                            "*** Model accuracy has not improved for {} consecutive epochs. Stopping training...***"
                            .format(args.patience))
                        break

                    if result['eval_accuracy'] >= best_eval_accuracy:
                        best_eval_accuracy = result['eval_accuracy']
                        best_model = model
                        # Save the best model
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model itself
                        output_model_file = os.path.join(
                            *[args.output_dir, task, "pytorch_model_best.bin"])
                        loading_path = output_model_file
                        logger.info(
                            "Saving checkpoint pytorch_model_best.bin to {}".
                            format(output_model_file))
                        torch.save(
                            {
                                'state_dict': model_to_save.state_dict(),
                                'options': opt,
                                'bert_config': bert_config
                            }, output_model_file)

            # for continual learning
            if args.do_prev_eval:
                if best_model is None:
                    best_model = model
                best_model.eval()

                # evaluate best model on current task
                dev_task = task
                result, _ = evaluate(args, best_model, eval_dataloader, device,
                                     task_type)
                logger.info("train_task: {}, eval_task: {}".format(
                    task, dev_task))
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))

                # evaluate new model on all previous tasks
                pre = best_model.module if hasattr(best_model,
                                                   'module') else best_model
                for j in range(i):
                    dev_task = all_tasks[j]

                    with torch.no_grad():
                        modify_model(best_model, dev_task, args)

                    processor = PROCESSORS[dev_task.lower()](args.num_ex)
                    num_labels = NUM_LABELS_TASK[dev_task.lower()]
                    task_type = TASK_TYPE[dev_task.lower()]
                    label_list = None
                    if task_type != 1:
                        label_list = processor.get_labels()
                    pre.num_labels = num_labels
                    pre.task_type = task_type
                    eval_dataloader = prepare_data_loader(args,
                                                          processor,
                                                          label_list,
                                                          task_type,
                                                          dev_task,
                                                          tokenizer,
                                                          split='dev')[0]

                    result, _ = evaluate(args, best_model, eval_dataloader,
                                         device, task_type)
                    logger.info("train_task: {}, eval_task: {}".format(
                        task, dev_task))
                    for key in sorted(result.keys()):
                        logger.info("  %s = %s", key, str(result[key]))

            tensorboard_writer.close()

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):

        eval_task_name = all_tasks[-1]
        logger.info('*** Start evaluating for {} ***'.format(eval_task_name))

        processor = PROCESSORS[eval_task_name.lower()](args.num_ex)
        num_labels = NUM_LABELS_TASK[eval_task_name.lower()]
        task_type = TASK_TYPE[eval_task_name.lower()]
        label_list = None
        if task_type != 1:
            label_list = processor.get_labels()

        # Load a trained model for evaluation
        if args.do_train:
            output_model_file = os.path.join(
                *[args.output_dir, all_tasks[-1], 'pytorch_model_best.bin'])
        else:
            output_model_file = args.load_ckpt  # os.path.join with a single argument is a no-op

        # prepare data
        split = args.data_split_attention if args.save_tpr_attentions else 'dev'

        return_mapping = {
            'pos': args.return_POS,
            'ner': args.return_NER,
            'dep_edge': args.return_DEP,
            'depth': args.return_CONST,
            'const': args.return_CONST
        }

        eval_dataloader, all_guids, structure_features = \
                            prepare_data_loader(args, processor, label_list, task_type, all_tasks[-1], tokenizer,
                                single_sentence=args.single_sentence, split=split, return_pos_tags=return_mapping['pos'],
                                return_ner_tags=return_mapping['ner'], return_dep_parse=return_mapping['dep_edge'],
                                return_const_parse=return_mapping['const'])

        all_tokens, token_pos, token_ner, token_dep, token_const = structure_features

        states = torch.load(output_model_file, map_location=device)
        model_state_dict = states['state_dict']
        opt = states['options']
        if 'nRoles' not in opt:
            for val in ['nRoles', 'nSymbols', 'dRoles', 'dSymbols']:
                opt[val] = getattr(args, val)

        bert_config = states['bert_config']
        if not isinstance(bert_config, PretrainedConfig):
            bert_dict = bert_config.to_dict()
            bert_dict['layer_norm_eps'] = 1e-12
            bert_config = PretrainedConfig.from_dict(bert_dict)

        if 'head.scale' in model_state_dict.keys():
            print('scale value is:', model_state_dict['head.scale'])
        logger.info('*' * 50)
        logger.info('option for evaluation: {}'.format(args))
        logger.info('*' * 50)
        # also print it for philly debugging
        print('option for evaluation: {}'.format(args))
        model = BertForSequenceClassification_tpr(
            bert_config,
            num_labels=num_labels,
            task_type=task_type,
            temperature=args.temperature,
            max_seq_len=args.max_seq_length,
            **opt)

        model.load_state_dict(model_state_dict, strict=True)

        if args.reset_temp_ratio != 1.0 and hasattr(model.head, 'temperature'):
            new_temp = model.head.temperature / args.reset_temp_ratio
            model.head.temperature = new_temp

        model.to(device)
        model.eval()
        result, (all_ids, F_list, R_list, F_full, R_full) = evaluate(
            args,
            model,
            eval_dataloader,
            device,
            task_type,
            data_split=split,
            save_tpr_attentions=args.save_tpr_attentions)

        if not os.path.exists(os.path.join(args.output_dir, eval_task_name)):
            os.makedirs(os.path.join(args.output_dir, eval_task_name))

        if (not args.save_tpr_attentions) and (args.do_tsne or args.do_Kmeans):
            logger.warning(
                'T-SNE and K-means will not be performed since attentions are not saved'
            )
            logger.warning(
                'Turn on save_tpr_attentions argument to save attentions')

        if args.save_tpr_attentions:
            output_attention_file = os.path.join(
                *[args.output_dir, eval_task_name, "tpr_attention.txt"])
            vals = prepare_structure_values(args, eval_task_name, all_ids,
                                            F_list, R_list, F_full, R_full,
                                            all_tokens, token_pos, token_ner,
                                            token_dep, token_const)
            if args.do_tsne:
                if return_mapping[args.tsne_label]:
                    perform_tsne(args, vals, args.tsne_label)
                else:
                    logger.warning(
                        'T-SNE cannot be performed with tsne_label: "{}" since values for that role are not saved'
                        .format(args.tsne_label))

            logger.info(
                'saving tpr_attentions to {} '.format(output_attention_file))
            with open(output_attention_file, "w") as fp:
                json.dump(vals, fp)

        output_eval_file = os.path.join(
            *[args.output_dir, eval_task_name, "eval_results.txt"])
        logger.info("***** Eval results *****")
        logger.info("  eval output file is in {}".format(output_eval_file))
        with open(output_eval_file, "w") as writer:
            writer.write(
                'exp_{:s}_{:.3f}_{:.6f}_{:.0f}_{:.1f}_{:.0f}_{:.0f}_{:.0f}_{:.0f}\n'
                .format(eval_task_name, args.temperature, args.learning_rate,
                        args.train_batch_size, args.num_train_epochs,
                        args.dSymbols, args.dRoles, args.nSymbols,
                        args.nRoles))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if args.do_test and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):

        test_task_name = all_tasks[-1]
        logger.info('*** Start testing for {} ***'.format(test_task_name))
        processor = PROCESSORS[test_task_name.lower()](args.num_ex)
        num_labels = NUM_LABELS_TASK[test_task_name.lower()]
        task_type = TASK_TYPE[test_task_name.lower()]
        label_list = None
        if task_type != 1:
            label_list = processor.get_labels()

        # Load a trained model for evaluation
        if args.do_train:
            output_model_file = os.path.join(
                *[args.output_dir, test_task_name, 'pytorch_model_best.bin'])
        else:
            output_model_file = args.load_ckpt  # os.path.join with a single argument is a no-op

        states = torch.load(output_model_file, map_location=device)
        model_state_dict = states['state_dict']
        opt = states['options']
        if 'nRoles' not in opt:
            print(args.nRoles)
            for val in ['nRoles', 'nSymbols', 'dRoles', 'dSymbols']:
                opt[val] = getattr(args, val)
        bert_config = states['bert_config']

        if not isinstance(bert_config, PretrainedConfig):
            bert_dict = bert_config.to_dict()
            bert_dict['layer_norm_eps'] = 1e-12
            bert_config = PretrainedConfig.from_dict(bert_dict)

        if 'head.scale' in model_state_dict.keys():
            print('scale value is:', model_state_dict['head.scale'])
        logger.info('*' * 50)
        logger.info('option for evaluation: {}'.format(args))
        logger.info('*' * 50)

        # also print it for philly debugging
        print('option for evaluation: {}'.format(args))
        model = BertForSequenceClassification_tpr(
            bert_config,
            num_labels=num_labels,
            task_type=task_type,
            temperature=args.temperature,
            max_seq_len=args.max_seq_length,
            **opt)

        model.load_state_dict(model_state_dict, strict=True)
        model.to(device)
        model.eval()

        if args.reset_temp_ratio != 1.0 and hasattr(model.head, 'temperature'):
            new_temp = model.head.temperature / args.reset_temp_ratio
            model.head.temperature = new_temp

        if test_task_name.lower().startswith('dnc'):
            test_examples = processor.get_all_examples(args.data_dir)
            # prepare test data
            for k in test_examples.keys():

                test_dataloader, all_guids = prepare_data_loader(
                    args,
                    processor,
                    label_list,
                    task_type,
                    test_task_name,
                    tokenizer,
                    split='test',
                    examples=test_examples)

                result = predict(args, model, test_dataloader, all_guids,
                                 device, task_type)

                if not os.path.exists(
                        os.path.join(args.output_dir, test_task_name)):
                    os.makedirs(os.path.join(args.output_dir, test_task_name))
                output_test_file = os.path.join(*[
                    args.output_dir, test_task_name,
                    "{}-test_predictions.txt".format(k)
                ])
                logger.info("***** Test predictions *****")
                logger.info(
                    "  test output file is in {}".format(output_test_file))
                with open(output_test_file, "w") as writer:
                    writer.write("index\tpredictions\n")
                    for guid, pred in zip(result['input_ids'],
                                          result['predictions']):
                        writer.write("%s\t%s\n" % (guid, label_list[pred]))

        else:
            # prepare test data
            test_dataloader, all_guids = prepare_data_loader(args,
                                                             processor,
                                                             label_list,
                                                             task_type,
                                                             test_task_name,
                                                             tokenizer,
                                                             split='test')[:2]

            result = predict(args, model, test_dataloader, all_guids, device,
                             task_type)

            if not os.path.exists(os.path.join(args.output_dir,
                                               test_task_name)):
                os.makedirs(os.path.join(args.output_dir, test_task_name))
            output_test_file = os.path.join(
                *[args.output_dir, test_task_name, "test_predictions.txt"])
            logger.info("***** Test predictions *****")
            logger.info("  test output file is in {}".format(output_test_file))
            with open(output_test_file, "w") as writer:
                writer.write("index\tpredictions\n")
                for guid, pred in zip(result['input_ids'],
                                      result['predictions']):
                    if test_task_name.lower() == 'hans':
                        if pred == 2:
                            pred = 0  # treat neutral as non-entailment
                        writer.write("%s,%s\n" % (guid, label_list[pred]))
                    elif task_type == 1:
                        writer.write("%s\t%s\n" % (guid, pred))
                    else:
                        writer.write("%s\t%s\n" % (guid, label_list[pred]))
                if test_task_name.lower() == 'snli':
                    writer.write("test_accuracy:\t%s\n" %
                                 (result['test_accuracy']))
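
The `decay` schedule applied to `pre.head.scale` during training is defined elsewhere in this project; its `mode` handling is unknown. As an illustration only, a per-step multiplicative decay whose cumulative effect over `t_total` optimizer steps reaches `final_ratio`:

import numpy as np

def decay(scale, mode, final_ratio, step, total_steps):
    # hypothetical sketch: `mode` and `step` are unused here, but the real
    # schedule presumably branches on them
    per_step = final_ratio ** (1.0 / max(total_steps, 1))
    return np.asarray(scale) * per_step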
Example #7
def do_prediction(net, output_path, test_ids, patch_size, stride, dilation, gate_image, masking,
                  post_processing, normalize_img, visualize=False):
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        os.makedirs(os.path.join(output_path, LABELS_DIR))

    # def model(img, mask=None):
    #     if masking:
    #         outputs = net(img, mask)
    #     else:
    #         outputs = net(img)
    #     return F.sigmoid(outputs[-1])

    def model(img):
        outputs1, outputs2 = net(img)
        pred = F.log_softmax(outputs1, dim=1)
        inside_mask = (pred.argmax(dim=1, keepdim=True) == 1).float()
        centroids = torch.sigmoid(outputs2)  # F.sigmoid is deprecated
        return torch.cat((inside_mask, centroids), 1)

    sum_agg_jac = 0
    sum_dice = 0
    for test_id in test_ids:
        img_path = os.path.join(INPUT_DIR, IMAGES_DIR, test_id+'.tif')
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if gate_image or masking:
            mask_path = os.path.join(INPUT_DIR, MASKS_DIR, test_id+'.png')
            gt_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) / 255

        if gate_image:
            img = img * np.repeat(gt_mask[:, :, np.newaxis], img.shape[-1], axis=2)

        if masking:
            pred = predict(model, img, patch_size, patch_size, stride, stride,
                           normalize_img=normalize_img, mask=gt_mask)
        else:
            pred = predict(model, img, patch_size, patch_size, stride, stride,
                           normalize_img=normalize_img)
        pred_labels = post_processing(pred, dilation=dilation)
        num_labels = np.max(pred_labels)
        colored_labels = \
            skimage.color.label2rgb(pred_labels, colors=helper.get_spaced_colors(num_labels)).astype(np.uint8)
        pred_labels_path = os.path.join(output_path, LABELS_DIR, test_id)
        pred_colored_labels_path = os.path.join(output_path, test_id+'.png')
        np.save(pred_labels_path, pred_labels)
        sio.savemat(pred_labels_path + '.mat', {'predicted_map': pred_labels}, do_compression=True)
        bgr_labels = cv2.cvtColor(colored_labels, cv2.COLOR_RGB2BGR)
        cv2.imwrite(pred_colored_labels_path, bgr_labels)

        if visualize:
            plt.imshow(img)
            plt.imshow(colored_labels, alpha=0.5)
            centroids = pred[-1] >= 0.5
            plt.imshow(centroids, alpha=0.5)
            plt.show()
            cv2.waitKey(0)

        labels_path = os.path.join(INPUT_DIR, LABELS_DIR, test_id+'.npy')
        gt_labels = np.load(labels_path).astype(int)  # np.int was removed in NumPy 1.24
        agg_jac = aggregated_jaccard(pred_labels, gt_labels)
        sum_agg_jac += agg_jac
        print('{}\'s Aggregated Jaccard Index: {:.4f}'.format(test_id, agg_jac))

        mask_path = os.path.join(INPUT_DIR, MASKS_DIR, test_id+'.png')
        gt_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) / 255
        pred_mask = skmorph.dilation(pred[0], skmorph.disk(dilation)) if dilation is not None else pred[0]
        dice = dice_index(pred_mask, gt_mask)
        sum_dice += dice
        print('{}\'s Dice Index: {:.4f}'.format(test_id, dice))

    print('--------------------------------------')
    print('Mean Aggregated Jaccard Index: {:.4f}'.format(sum_agg_jac/len(test_ids)))
    print('Mean Dice Index: {:.4f}'.format(sum_dice/len(test_ids)))