Пример #1
0
    def train_epoch(self, scaler, epoch, model, dataset, dataloader, optimizer, prefix="train"):
        """Run one training epoch and log losses/metrics.

        Args:
            scaler: AMP grad scaler forwarded to ``run_step``.
            epoch: current epoch index (0-based).
            model: model to train; switched to ``train()`` mode here.
            dataset: dataset used to build the performance evaluator.
            dataloader: iterable of training samples.
            optimizer: optimizer whose learning rate is reported in logs.
            prefix: tag used for the forward pass and tensorboard names.
        """
        model.train()

        _timer = Timer()
        lossLogger = LossLogger()
        performanceLogger = build_evaluator(self.cfg, dataset)

        num_iters = len(dataloader)
        for i, sample in enumerate(dataloader):
            self.n_iters_elapsed += 1
            _timer.tic()
            self.run_step(scaler, model, sample, optimizer, lossLogger, performanceLogger, prefix)
            # Ensure the timing below reflects finished GPU work, not just
            # asynchronous kernel launches.
            torch.cuda.synchronize()
            _timer.toc()

            if (i + 1) % self.cfg.N_ITERS_TO_DISPLAY_STATUS == 0:
                # Only the main process logs in distributed runs.
                if self.cfg.local_rank == 0:
                    template = "[epoch {}/{}, iter {}/{}, lr {}] Total train loss: {:.4f} " "(ips = {:.2f})\n" "{}"
                    logger.info(
                        template.format(
                            epoch, self.cfg.N_MAX_EPOCHS - 1, i, num_iters - 1,
                            round(get_current_lr(optimizer), 6),
                            lossLogger.meters["loss"].value,
                            # images-per-second over the last display window
                            self.batch_size * self.cfg.N_ITERS_TO_DISPLAY_STATUS / _timer.diff,
                            "\n".join(
                                ["{}: {:.4f}".format(n, l.value) for n, l in lossLogger.meters.items() if n != "loss"]),
                        )
                    )

        if self.cfg.TENSORBOARD and self.cfg.local_rank == 0:
            # Logging train losses — explicit loops instead of list
            # comprehensions used purely for their side effects.
            for n, l in lossLogger.meters.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l.global_avg, epoch)
            performances = performanceLogger.evaluate()
            if performances is not None and len(performances):
                for k, v in performances.items():
                    self.tb_writer.add_scalar(f"performance/{prefix}_{k}", v, epoch)

        # Weight-histogram logging is intentionally disabled via ``and False``;
        # drop that guard to re-enable it.
        if self.cfg.TENSORBOARD_WEIGHT and False:
            for name, param in model.named_parameters():
                layer, attr = os.path.splitext(name)
                attr = attr[1:]  # strip the leading '.'
                self.tb_writer.add_histogram("{}/{}".format(layer, attr), param, epoch)
Пример #2
0
def text_detect(text_detector, im):
    """Detect and draw text lines on an image.

    Args:
        text_detector: detector object exposing ``detect(image)``.
        im: input image (as loaded by cv2).

    Returns:
        The detected text lines after being drawn/projected by ``draw_boxes``.
    """
    im_small, f, im_height, im_width = resize_im(im, Config.SCALE,
                                                 Config.MAX_SCALE)

    timer = Timer()
    timer.tic()
    text_lines = text_detector.detect(im_small)
    text_lines = draw_boxes(im_small, text_lines, f, im_height, im_width)
    # Parenthesized single-argument prints work identically on Python 2 and 3,
    # unlike the original Python-2-only print statements.
    print("Number of the detected text lines: %s" % len(text_lines))
    print("Detection Time: %f" % timer.toc())
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

    return text_lines
Пример #3
0
def text_detect(text_detector, im, img_type, image_path=None):
    """Detect text lines in an image and compute their area ratio.

    Args:
        text_detector: detector object exposing ``detect(image)``.
        im: input image (as loaded by cv2).
        img_type: category of the image; "others" is skipped entirely.
        image_path: optional path stem for the debug box-image dump. The
            original code read a name ``image_path`` that was never defined
            in this scope, raising NameError whenever
            ``Config.DEBUG_SAVE_BOX_IMG`` was set; callers that want the
            debug image must now pass the path explicitly.

    Returns:
        Tuple of (text_lines, text_area_ratio); ([], 0) for "others" images.
    """
    if img_type == "others":
        return [], 0

    im_small, f = resize_im(im, Config.SCALE, Config.MAX_SCALE)

    timer = Timer()
    timer.tic()
    text_lines = text_detector.detect(im_small)
    text_lines = text_lines / f  # project back to size of original image
    text_lines = refine_boxes(im, text_lines, expand_pixel_len=Config.DILATE_PIXEL,
                              pixel_blank=Config.BREATH_PIXEL, binary_thresh=Config.BINARY_THRESH)
    text_area_ratio = calc_area_ratio(text_lines, im.shape)
    # Parenthesized single-argument prints are Python 2/3 compatible.
    print("Number of the detected text lines: %s" % len(text_lines))
    print("Detection Time: %f" % timer.toc())
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

    if Config.DEBUG_SAVE_BOX_IMG and image_path is not None:
        im_with_text_lines = draw_boxes(im, text_lines, is_display=False, caption=image_path, wait=False)
        if im_with_text_lines is not None:
            cv2.imwrite(image_path + '_boxes.jpg', im_with_text_lines)

    return text_lines, text_area_ratio
Пример #4
0
def eval_model(model,
               classes,
               bm,
               last_epoch=True,
               verbose=False,
               xls_sheet=None):
    """Evaluate a graph-matching model on a benchmark, class by class.

    Handles three problem types read from ``cfg.PROBLEM.TYPE``: '2GM'
    (two-graph matching), 'MGM' and 'MGM3' (multi-graph matching; 'MGM3'
    additionally reports clustering accuracy/purity/rand-index). Collects
    per-class precision/recall/f1/coverage via ``bm.eval``, normalized
    objective scores when affinity matrices are produced, and prediction
    time; optionally writes every metric into ``xls_sheet``.

    Args:
        model: wrapped matching model; ``model.module.device`` is consulted
            to decide whether inputs go to CUDA.
        classes: classes to evaluate. NOTE(review): for 'MGM3',
            ``classes[0]`` is itself passed to ``bm.eval`` and iterated —
            presumably a collection of class names; confirm against callers.
        bm: benchmark providing datasets, ``eval``/``eval_cls`` and
            ground-truth cache management.
        last_epoch: forwarded to ``bm.rm_gt_cache`` at the end.
        verbose: print per-class progress and metrics.
        xls_sheet: optional xlwt-style worksheet to record results into.
            NOTE(review): the header/summary writes use the loop variable
            ``idx`` after the loops end, so this assumes ``classes`` is
            non-empty.

    Returns:
        torch.Tensor of per-class recalls.
    """
    print('Start evaluation...')
    since = time.time()

    device = next(model.parameters()).device

    # Remember training mode so it can be restored at the end.
    was_training = model.training
    model.eval()

    # One dataloader per evaluation class.
    dataloaders = []

    for cls in classes:
        image_dataset = GMDataset(cfg.DATASET_FULL_NAME, bm, cfg.EVAL.SAMPLES,
                                  cfg.PROBLEM.TEST_ALL_GRAPHS, cls,
                                  cfg.PROBLEM.TYPE)

        torch.manual_seed(cfg.RANDOM_SEED
                          )  # Fix fetched data in test-set to prevent variance

        dataloader = get_dataloader(image_dataset, shuffle=True)
        dataloaders.append(dataloader)

    # Per-class metric accumulators.
    recalls = []
    precisions = []
    f1s = []
    coverages = []
    pred_time = []
    objs = torch.zeros(len(classes), device=device)
    cluster_acc = []
    cluster_purity = []
    cluster_ri = []

    timer = Timer()

    prediction = []

    for i, cls in enumerate(classes):
        if verbose:
            print('Evaluating class {}: {}/{}'.format(cls, i, len(classes)))

        running_since = time.time()
        iter_num = 0

        pred_time_list = []
        obj_total_num = torch.zeros(1, device=device)
        cluster_acc_list = []
        cluster_purity_list = []
        cluster_ri_list = []
        prediction_cls = []

        for inputs in dataloaders[i]:
            # Cap the number of evaluated samples per class.
            if iter_num >= cfg.EVAL.SAMPLES / inputs['batch_size']:
                break
            if model.module.device != torch.device('cpu'):
                inputs = data_to_cuda(inputs)

            batch_num = inputs['batch_size']

            iter_num = iter_num + 1

            # Forward pass without gradients; time it per sample.
            with torch.set_grad_enabled(False):
                timer.tick()
                outputs = model(inputs)
                pred_time_list.append(
                    torch.full((batch_num, ),
                               timer.toc() / batch_num))

            # Evaluate matching accuracy
            if cfg.PROBLEM.TYPE == '2GM':
                assert 'perm_mat' in outputs

                # Crop each predicted permutation matrix to the true node
                # counts of the source/target graphs before storing it.
                for b in range(outputs['perm_mat'].shape[0]):
                    perm_mat = outputs['perm_mat'][
                        b, :outputs['ns'][0][b], :outputs['ns'][1][b]].cpu()
                    perm_mat = perm_mat.numpy()
                    eval_dict = dict()
                    id_pair = inputs['id_list'][0][b], inputs['id_list'][1][b]
                    eval_dict['ids'] = id_pair
                    eval_dict['cls'] = cls
                    eval_dict['perm_mat'] = perm_mat
                    prediction.append(eval_dict)
                    prediction_cls.append(eval_dict)

                # Normalized objective score: prediction vs ground truth.
                if 'aff_mat' in outputs:
                    pred_obj_score = objective_score(outputs['perm_mat'],
                                                     outputs['aff_mat'])
                    gt_obj_score = objective_score(outputs['gt_perm_mat'],
                                                   outputs['aff_mat'])
                    objs[i] += torch.sum(pred_obj_score / gt_obj_score)
                    obj_total_num += batch_num
            elif cfg.PROBLEM.TYPE in ['MGM', 'MGM3']:
                assert 'graph_indices' in outputs
                assert 'perm_mat_list' in outputs

                # One predicted matrix per (source, target) graph pair.
                ns = outputs['ns']
                idx = -1
                for x_pred, (idx_src, idx_tgt) in \
                        zip(outputs['perm_mat_list'], outputs['graph_indices']):
                    idx += 1
                    for b in range(x_pred.shape[0]):
                        perm_mat = x_pred[
                            b, :ns[idx_src][b], :ns[idx_tgt][b]].cpu()
                        perm_mat = perm_mat.numpy()
                        eval_dict = dict()
                        id_pair = inputs['id_list'][idx_src][b], inputs[
                            'id_list'][idx_tgt][b]
                        eval_dict['ids'] = id_pair
                        if cfg.PROBLEM.TYPE == 'MGM3':
                            # Mixed-class batches: look up the true class.
                            eval_dict['cls'] = bm.data_dict[id_pair[0]]['cls']
                        else:
                            eval_dict['cls'] = cls
                        eval_dict['perm_mat'] = perm_mat
                        prediction.append(eval_dict)
                        prediction_cls.append(eval_dict)

            else:
                raise ValueError('Unknown problem type {}'.format(
                    cfg.PROBLEM.TYPE))

            # Evaluate clustering accuracy
            if cfg.PROBLEM.TYPE == 'MGM3':
                assert 'pred_cluster' in outputs
                assert 'cls' in outputs

                pred_cluster = outputs['pred_cluster']
                # Transpose class labels from per-graph lists to per-sample
                # lists so they align with pred_cluster.
                cls_gt_transpose = [[] for _ in range(batch_num)]
                for batched_cls in outputs['cls']:
                    for b, _cls in enumerate(batched_cls):
                        cls_gt_transpose[b].append(_cls)
                cluster_acc_list.append(
                    clustering_accuracy(pred_cluster, cls_gt_transpose))
                cluster_purity_list.append(
                    clustering_purity(pred_cluster, cls_gt_transpose))
                cluster_ri_list.append(
                    rand_index(pred_cluster, cls_gt_transpose))

            if iter_num % cfg.STATISTIC_STEP == 0 and verbose:
                running_speed = cfg.STATISTIC_STEP * batch_num / (
                    time.time() - running_since)
                print('Class {:<8} Iteration {:<4} {:>4.2f}sample/s'.format(
                    cls, iter_num, running_speed))
                running_since = time.time()

        # Per-class aggregation.
        objs[i] = objs[i] / obj_total_num
        pred_time.append(torch.cat(pred_time_list))
        if cfg.PROBLEM.TYPE == 'MGM3':
            cluster_acc.append(torch.cat(cluster_acc_list))
            cluster_purity.append(torch.cat(cluster_purity_list))
            cluster_ri.append(torch.cat(cluster_ri_list))

        if verbose:
            if cfg.PROBLEM.TYPE != 'MGM3':
                bm.eval_cls(prediction_cls, cls, verbose=verbose)
            print('Class {} norm obj score = {:.4f}'.format(cls, objs[i]))
            print('Class {} pred time = {}s'.format(
                cls, format_metric(pred_time[i])))
            if cfg.PROBLEM.TYPE == 'MGM3':
                print('Class {} cluster acc={}'.format(
                    cls, format_metric(cluster_acc[i])))
                print('Class {} cluster purity={}'.format(
                    cls, format_metric(cluster_purity[i])))
                print('Class {} cluster rand index={}'.format(
                    cls, format_metric(cluster_ri[i])))

    # Matching metrics (precision/recall/f1/coverage) from the benchmark.
    if cfg.PROBLEM.TYPE == 'MGM3':
        result = bm.eval(prediction, classes[0], verbose=True)
        for cls in classes[0]:
            precision = result[cls]['precision']
            recall = result[cls]['recall']
            f1 = result[cls]['f1']
            coverage = result[cls]['coverage']

            recalls.append(recall)
            precisions.append(precision)
            f1s.append(f1)
            coverages.append(coverage)
    else:
        result = bm.eval(prediction, classes, verbose=True)
        for cls in classes:
            precision = result[cls]['precision']
            recall = result[cls]['recall']
            f1 = result[cls]['f1']
            coverage = result[cls]['coverage']

            recalls.append(recall)
            precisions.append(precision)
            f1s.append(f1)
            coverages.append(coverage)

    time_elapsed = time.time() - since
    print('Evaluation complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # Restore the caller's train/eval mode.
    model.train(mode=was_training)

    # Spreadsheet header row: class names followed by a 'mean' column.
    if xls_sheet:
        for idx, cls in enumerate(classes):
            xls_sheet.write(0, idx + 1, cls)
        xls_sheet.write(0, idx + 2, 'mean')

    xls_row = 1

    # show result
    if xls_sheet:
        xls_sheet.write(xls_row, 0, 'precision')
        xls_sheet.write(xls_row + 1, 0, 'recall')
        xls_sheet.write(xls_row + 2, 0, 'f1')
        xls_sheet.write(xls_row + 3, 0, 'coverage')
    for idx, (cls, cls_p, cls_r, cls_f1, cls_cvg) in enumerate(
            zip(classes, precisions, recalls, f1s, coverages)):
        if xls_sheet:
            xls_sheet.write(
                xls_row, idx + 1,
                '{:.4f}'.format(cls_p))  #'{:.4f}'.format(torch.mean(cls_p)))
            xls_sheet.write(
                xls_row + 1, idx + 1,
                '{:.4f}'.format(cls_r))  #'{:.4f}'.format(torch.mean(cls_r)))
            xls_sheet.write(
                xls_row + 2, idx + 1,
                '{:.4f}'.format(cls_f1))  #'{:.4f}'.format(torch.mean(cls_f1)))
            xls_sheet.write(xls_row + 3, idx + 1, '{:.4f}'.format(cls_cvg))
    if xls_sheet:
        xls_sheet.write(xls_row, idx + 2,
                        '{:.4f}'.format(result['mean']['precision'])
                        )  #'{:.4f}'.format(torch.mean(torch.cat(precisions))))
        xls_sheet.write(xls_row + 1, idx + 2, '{:.4f}'.format(
            result['mean']
            ['recall']))  #'{:.4f}'.format(torch.mean(torch.cat(recalls))))
        xls_sheet.write(xls_row + 2, idx + 2, '{:.4f}'.format(
            result['mean']
            ['f1']))  #'{:.4f}'.format(torch.mean(torch.cat(f1s))))
        xls_row += 4

    # Objective scores (skipped entirely if any class produced NaN, e.g.
    # when no 'aff_mat' was available and the per-class division gave 0/0).
    if not torch.any(torch.isnan(objs)):
        print('Normalized objective score')
        if xls_sheet: xls_sheet.write(xls_row, 0, 'norm objscore')
        for idx, (cls, cls_obj) in enumerate(zip(classes, objs)):
            print('{} = {:.4f}'.format(cls, cls_obj))
            if xls_sheet:
                xls_sheet.write(xls_row, idx + 1,
                                cls_obj.item())  #'{:.4f}'.format(cls_obj))
        print('average objscore = {:.4f}'.format(torch.mean(objs)))
        if xls_sheet:
            xls_sheet.write(
                xls_row, idx + 2,
                torch.mean(objs).item())  #'{:.4f}'.format(torch.mean(objs)))
            xls_row += 1

    # Clustering metrics are only defined for the MGM3 problem type.
    if cfg.PROBLEM.TYPE == 'MGM3':
        print('Clustering accuracy')
        if xls_sheet: xls_sheet.write(xls_row, 0, 'cluster acc')
        for idx, (cls, cls_acc) in enumerate(zip(classes, cluster_acc)):
            print('{} = {}'.format(cls, format_metric(cls_acc)))
            if xls_sheet:
                xls_sheet.write(xls_row, idx + 1,
                                torch.mean(cls_acc).item()
                                )  #'{:.4f}'.format(torch.mean(cls_acc)))
        print('average clustering accuracy = {}'.format(
            format_metric(torch.cat(cluster_acc))))
        if xls_sheet:
            xls_sheet.write(
                xls_row, idx + 2,
                torch.mean(torch.cat(cluster_acc)).item(
                ))  #'{:.4f}'.format(torch.mean(torch.cat(cluster_acc))))
            xls_row += 1

        print('Clustering purity')
        if xls_sheet: xls_sheet.write(xls_row, 0, 'cluster purity')
        for idx, (cls, cls_acc) in enumerate(zip(classes, cluster_purity)):
            print('{} = {}'.format(cls, format_metric(cls_acc)))
            if xls_sheet:
                xls_sheet.write(xls_row, idx + 1,
                                torch.mean(cls_acc).item()
                                )  #'{:.4f}'.format(torch.mean(cls_acc)))
        print('average clustering purity = {}'.format(
            format_metric(torch.cat(cluster_purity))))
        if xls_sheet:
            xls_sheet.write(
                xls_row, idx + 2,
                torch.mean(torch.cat(cluster_purity)).item(
                ))  #'{:.4f}'.format(torch.mean(torch.cat(cluster_purity))))
            xls_row += 1

        print('Clustering rand index')
        if xls_sheet: xls_sheet.write(xls_row, 0, 'rand index')
        for idx, (cls, cls_acc) in enumerate(zip(classes, cluster_ri)):
            print('{} = {}'.format(cls, format_metric(cls_acc)))
            if xls_sheet:
                xls_sheet.write(xls_row, idx + 1,
                                torch.mean(cls_acc).item()
                                )  #'{:.4f}'.format(torch.mean(cls_acc)))
        print('average rand index = {}'.format(
            format_metric(torch.cat(cluster_ri))))
        if xls_sheet:
            xls_sheet.write(
                xls_row, idx + 2,
                torch.mean(torch.cat(cluster_ri)).item(
                ))  #'{:.4f}'.format(torch.mean(torch.cat(cluster_ri))))
            xls_row += 1

    print('Predict time')
    if xls_sheet: xls_sheet.write(xls_row, 0, 'time')
    for idx, (cls, cls_time) in enumerate(zip(classes, pred_time)):
        print('{} = {}'.format(cls, format_metric(cls_time)))
        if xls_sheet:
            xls_sheet.write(
                xls_row, idx + 1,
                torch.mean(
                    cls_time).item())  #'{:.4f}'.format(torch.mean(cls_time)))
    print('average time = {}'.format(format_metric(torch.cat(pred_time))))
    if xls_sheet:
        xls_sheet.write(xls_row, idx + 2,
                        torch.mean(torch.cat(pred_time)).item()
                        )  #'{:.4f}'.format(torch.mean(torch.cat(pred_time))))
        xls_row += 1

    bm.rm_gt_cache(last_epoch=last_epoch)

    return torch.Tensor(recalls)
Пример #5
0
    def train_epoch(self,
                    scaler,
                    epoch,
                    model,
                    dataloader,
                    optimizer,
                    prefix="train"):
        """Run one mixed-precision (native AMP) training epoch.

        Args:
            scaler: torch.cuda.amp.GradScaler handling loss scaling.
            epoch: current epoch index.
            model: model to train; switched to ``train()`` mode here.
            dataloader: yields dicts with 'image' and 'target' entries.
            optimizer: optimizer stepped through the scaler.
            prefix: tag used in the forward call and tensorboard names.
        """
        model.train()

        _timer = Timer()
        lossLogger = LossLogger()
        performanceLogger = MetricLogger(self.dictionary, self.cfg)

        for i, sample in enumerate(dataloader):
            imgs, targets = sample['image'], sample['target']
            _timer.tic()
            # zero the parameter gradients
            optimizer.zero_grad()

            # Move inputs to GPU; images may arrive as a list (detection-style
            # batches) or as a single stacked tensor.
            imgs = list(
                img.cuda()
                for img in imgs) if isinstance(imgs, list) else imgs.cuda()
            if isinstance(targets, list):
                if isinstance(targets[0], torch.Tensor):
                    targets = [t.cuda() for t in targets]
                else:
                    targets = [{k: v.cuda()
                                for k, v in t.items()} for t in targets]
            else:
                targets = targets.cuda()

            # Autocast
            with amp.autocast(enabled=True):
                out = model(imgs, targets, prefix)

            if not isinstance(out, tuple):
                losses, predicts = out, None
            else:
                losses, predicts = out

            self.n_iters_elapsed += 1

            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(losses["loss"]).backward()
            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)
            # Updates the scale for next iteration.
            scaler.update()

            # torch.cuda.synchronize()
            _timer.toc()

            if (i + 1) % self.cfg.N_ITERS_TO_DISPLAY_STATUS == 0:
                if self.cfg.distributed:
                    # reduce losses over all GPUs for logging purposes
                    loss_dict_reduced = reduce_dict(losses)
                    lossLogger.update(**loss_dict_reduced)
                    del loss_dict_reduced
                else:
                    lossLogger.update(**losses)

                if predicts is not None:
                    if self.cfg.distributed:
                        # reduce performances over all GPUs for logging purposes
                        predicts_dict_reduced = reduce_dict(predicts)
                        # NOTE(review): this branch passes (targets, dict)
                        # positionally while the non-distributed branch unpacks
                        # with **predicts — confirm MetricLogger.update really
                        # supports both calling conventions.
                        performanceLogger.update(targets,
                                                 predicts_dict_reduced)
                        del predicts_dict_reduced
                    else:
                        performanceLogger.update(**predicts)
                    del predicts

                # Only the main process logs in distributed runs.
                if self.cfg.local_rank == 0:
                    template = "[epoch {}/{}, iter {}, lr {}] Total train loss: {:.4f} " "(ips = {:.2f})\n" "{}"
                    logger.info(
                        template.format(
                            epoch,
                            self.cfg.N_MAX_EPOCHS,
                            i,
                            round(get_current_lr(optimizer), 6),
                            lossLogger.meters["loss"].value,
                            # images-per-second over the last display window
                            self.batch_size *
                            self.cfg.N_ITERS_TO_DISPLAY_STATUS / _timer.diff,
                            "\n".join([
                                "{}: {:.4f}".format(n, l.value)
                                for n, l in lossLogger.meters.items()
                                if n != "loss"
                            ]),
                        ))

            del imgs, targets, losses

        if self.cfg.TENSORBOARD and self.cfg.local_rank == 0:
            # Logging train losses — explicit loops instead of list
            # comprehensions used purely for their side effects.
            for n, l in lossLogger.meters.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l.global_avg,
                                          epoch)
            performances = performanceLogger.compute()
            if len(performances):
                for k, v in performances.items():
                    self.tb_writer.add_scalar(f"performance/{prefix}_{k}", v,
                                              epoch)

        # Weight-histogram logging is intentionally disabled via ``and False``.
        if self.cfg.TENSORBOARD_WEIGHT and False:
            for name, param in model.named_parameters():
                layer, attr = os.path.splitext(name)
                attr = attr[1:]  # strip the leading '.'
                self.tb_writer.add_histogram("{}/{}".format(layer, attr),
                                             param, epoch)
Пример #6
0
    def train_epoch(self,
                    epoch,
                    model,
                    dataloader,
                    optimizer,
                    lr_scheduler,
                    grad_normalizer=None,
                    prefix="train"):
        """Run one training epoch with optional gradient normalization.

        Args:
            epoch: current epoch index.
            model: model to train; switched to ``train()`` mode here.
            dataloader: yields (imgs, labels) batches.
            optimizer: optimizer stepped every iteration.
            lr_scheduler: stepped once at the end of the epoch.
            grad_normalizer: used instead of a plain backward pass when
                ``cfg.GRADNORM`` is set.
            prefix: tag used in the forward call and tensorboard names.
        """
        model.train()

        _timer = Timer()
        lossMeter = LossMeter()
        perfMeter = PerfMeter()

        for i, (imgs, labels) in enumerate(dataloader):
            _timer.tic()
            # zero the parameter gradients
            optimizer.zero_grad()

            if self.cfg.HALF:
                imgs = imgs.half()

            # Multi-GPU path uses data_parallel; single-GPU path moves the
            # batch to CUDA and calls the model directly.
            if len(self.device) > 1:
                out = data_parallel(model, (imgs, labels, prefix),
                                    device_ids=self.device,
                                    output_device=self.device[0])
            else:
                imgs = imgs.cuda()
                labels = [label.cuda() for label in labels] if isinstance(
                    labels, list) else labels.cuda()
                out = model(imgs, labels, prefix)

            if not isinstance(out, tuple):
                losses, performances = out, None
            else:
                losses, performances = out

            if losses["all_loss"].sum().requires_grad:
                if self.cfg.GRADNORM is not None:
                    grad_normalizer.adjust_losses(losses)
                    grad_normalizer.adjust_grad(model, losses)
                else:
                    losses["all_loss"].sum().backward()

            optimizer.step()

            self.n_iters_elapsed += 1

            _timer.toc()

            # NOTE(review): LossMeter apparently uses __add__ as its
            # accumulate method; the explicit dunder call is kept as-is to
            # preserve behavior — consider giving LossMeter a named method.
            lossMeter.__add__(losses)

            if performances is not None and all(performances):
                perfMeter.put(performances)

            if (i + 1) % self.cfg.N_ITERS_TO_DISPLAY_STATUS == 0:
                avg_losses = lossMeter.average()
                template = "[epoch {}/{}, iter {}, lr {}] Total train loss: {:.4f} " "(ips = {:.2f} )\n" "{}"
                self.logger.info(
                    template.format(
                        epoch,
                        self.cfg.N_MAX_EPOCHS,
                        i,
                        round(get_current_lr(optimizer), 6),
                        avg_losses["all_loss"],
                        # images-per-second over the whole epoch so far
                        self.batch_size * self.cfg.N_ITERS_TO_DISPLAY_STATUS /
                        _timer.total_time,
                        "\n".join([
                            "{}: {:.4f}".format(n, l)
                            for n, l in avg_losses.items() if n != "all_loss"
                        ]),
                    ))

                if self.cfg.TENSORBOARD:
                    tb_step = int((epoch * self.n_steps_per_epoch + i) /
                                  self.cfg.N_ITERS_TO_DISPLAY_STATUS)
                    # Logging train losses — explicit loop instead of a list
                    # comprehension used purely for its side effects.
                    for n, l in avg_losses.items():
                        self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l,
                                                  tb_step)

                lossMeter.clear()

            del imgs, labels, losses, performances

        lr_scheduler.step()

        if self.cfg.TENSORBOARD and len(perfMeter):
            avg_perf = perfMeter.average()
            for k, v in avg_perf.items():
                self.tb_writer.add_scalar(f"performance/{prefix}_{k}", v,
                                          epoch)

        # Weight-histogram logging is intentionally disabled via ``and False``.
        if self.cfg.TENSORBOARD_WEIGHT and False:
            for name, param in model.named_parameters():
                layer, attr = os.path.splitext(name)
                attr = attr[1:]  # strip the leading '.'
                self.tb_writer.add_histogram("{}/{}".format(layer, attr),
                                             param, epoch)
Пример #7
0
# Demo driver: run the text detector over every image in DEMO_IMAGE_DIR.
# Parenthesized single-argument prints behave identically on Python 2 and 3,
# unlike the original Python-2-only print statements.
text_proposals_detector = TextProposalDetector(
    CaffeModel(NET_DEF_FILE, MODEL_FILE))
text_detector = TextDetector(text_proposals_detector)

demo_imnames = os.listdir(DEMO_IMAGE_DIR)
timer = Timer()

for im_name in demo_imnames:
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print("Image: %s" % im_name)

    im_file = osp.join(DEMO_IMAGE_DIR, im_name)
    im = cv2.imread(im_file)

    timer.tic()

    im, f = resize_im(im, cfg.SCALE, cfg.MAX_SCALE)
    text_lines = text_detector.detect(im)

    print("Number of the detected text lines: %s" % len(text_lines))
    print("Time: %f" % timer.toc())

    im_with_text_lines = draw_boxes(im,
                                    text_lines,
                                    caption=im_name,
                                    wait=False)

print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print("Thank you for trying our demo. Press any key to exit...")
cv2.waitKey(0)
Пример #8
0
    def train_epoch(self,
                    epoch,
                    model,
                    dataloader,
                    optimizer,
                    lr_scheduler,
                    grad_normalizer=None,
                    prefix="train"):
        """Run one training epoch using apex-style amp loss scaling.

        Args:
            epoch: current epoch index.
            model: model to train; switched to ``train()`` mode here.
            dataloader: yields (imgs, targets) batches.
            optimizer: optimizer stepped every iteration.
            lr_scheduler: accepted for interface parity.
                NOTE(review): never stepped inside this method — confirm the
                caller advances it.
            grad_normalizer: accepted but unused in this variant.
            prefix: tag used in the forward call and tensorboard names.
        """
        model.train()

        _timer = Timer()
        lossLogger = MetricLogger(delimiter="  ")
        performanceLogger = MetricLogger(delimiter="  ")

        for i, (imgs, targets) in enumerate(dataloader):
            _timer.tic()
            # zero the parameter gradients
            optimizer.zero_grad()

            # Move inputs to GPU; images may arrive as a list (detection-style
            # batches) or as a single stacked tensor.
            imgs = list(
                img.cuda()
                for img in imgs) if isinstance(imgs, list) else imgs.cuda()
            if isinstance(targets, list):
                if isinstance(targets[0], torch.Tensor):
                    targets = [t.cuda() for t in targets]
                else:
                    targets = [{k: v.cuda()
                                for k, v in t.items()} for t in targets]
            else:
                targets = targets.cuda()

            out = model(imgs, targets, prefix)

            if not isinstance(out, tuple):
                losses, performances = out, None
            else:
                losses, performances = out

            self.n_iters_elapsed += 1

            # apex amp: scale the loss to avoid fp16 underflow in backward.
            with amp.scale_loss(losses["loss"], optimizer) as scaled_loss:
                scaled_loss.backward()

            optimizer.step()

            # Ensure the timing below reflects finished GPU work.
            torch.cuda.synchronize()
            _timer.toc()

            if (i + 1) % self.cfg.N_ITERS_TO_DISPLAY_STATUS == 0:
                if self.cfg.distributed:
                    # reduce losses over all GPUs for logging purposes
                    loss_dict_reduced = reduce_dict(losses)
                    lossLogger.update(**loss_dict_reduced)
                else:
                    lossLogger.update(**losses)

                # NOTE(review): if ``performances`` is a dict, ``all()``
                # iterates its keys — confirm this truthiness check matches
                # the intended "all metrics present" semantics.
                if performances is not None and all(performances):
                    if self.cfg.distributed:
                        # reduce performances over all GPUs for logging purposes
                        performance_dict_reduced = reduce_dict(performances)
                        performanceLogger.update(**performance_dict_reduced)
                    else:
                        performanceLogger.update(**performances)

                # Only the main process logs in distributed runs.
                if self.cfg.local_rank == 0:
                    template = "[epoch {}/{}, iter {}, lr {}] Total train loss: {:.4f} " "(ips = {:.2f})\n" "{}"
                    logger.info(
                        template.format(
                            epoch,
                            self.cfg.N_MAX_EPOCHS,
                            i,
                            round(get_current_lr(optimizer), 6),
                            lossLogger.meters["loss"].value,
                            # images-per-second over the whole epoch so far
                            self.batch_size *
                            self.cfg.N_ITERS_TO_DISPLAY_STATUS /
                            _timer.total_time,
                            "\n".join([
                                "{}: {:.4f}".format(n, l.value)
                                for n, l in lossLogger.meters.items()
                                if n != "loss"
                            ]),
                        ))

            del imgs, targets

        if self.cfg.TENSORBOARD and self.cfg.local_rank == 0:
            # Logging train losses — explicit loops instead of list
            # comprehensions used purely for their side effects.
            for n, l in lossLogger.meters.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l.global_avg,
                                          epoch)
            if len(performanceLogger.meters):
                for k, v in performanceLogger.meters.items():
                    self.tb_writer.add_scalar(f"performance/{prefix}_{k}",
                                              v.global_avg, epoch)

        # Weight-histogram logging is intentionally disabled via ``and False``.
        if self.cfg.TENSORBOARD_WEIGHT and False:
            for name, param in model.named_parameters():
                layer, attr = os.path.splitext(name)
                attr = attr[1:]  # strip the leading '.'
                self.tb_writer.add_histogram("{}/{}".format(layer, attr),
                                             param, epoch)