Example no. 1
    def __init__(self, root, image_set, is_train, transform=None):
        super().__init__(root, image_set, is_train, transform)

        anno_file = osp.join(self.root, 'h36m', 'annot',
                             'h36m_{}.pkl'.format(image_set))
        self.db = self.load_db(anno_file)

        if cfg.DATASETS.H36M.FILTER_DAMAGE:
            print('before filter', len(self.db))
            self.db = [
                db_rec for db_rec in self.db if not self.isdamaged(db_rec)
            ]
            print('after filter', len(self.db))

        if cfg.DATASETS.H36M.MAPPING:
            assert cfg.KEYPOINT.NUM_PTS == 20
            self.u2a_mapping = super().get_mapping()
            super().do_mapping()
        else:
            assert cfg.KEYPOINT.NUM_PTS == 17

        self.grouping = self.get_group(self.db)
        self.group_size = len(self.grouping)

        if cfg.VIS.MULTIVIEWH36M:
            from utils.metric_logger import MetricLogger
            self.meters = MetricLogger()
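Every example in this collection logs statistics through a MetricLogger. The exact utility differs slightly across the repositories these snippets come from; below is a minimal sketch in the spirit of the maskrcnn-benchmark logger, extended with the mean, last_item and get_all_avg accessors that later examples call (those names and their return shapes are assumptions inferred from usage, not the original code):

from collections import defaultdict, deque

import torch


class SmoothedValue:
    """Track a series of values and expose windowed / global statistics."""

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        return float(torch.tensor(list(self.deque), dtype=torch.float32).median())

    @property
    def avg(self):
        return float(torch.tensor(list(self.deque), dtype=torch.float32).mean())

    @property
    def global_avg(self):
        return self.total / self.count


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            self.meters[k].update(float(v))

    def __getattr__(self, attr):
        # lets callers write e.g. meters.time.global_avg
        if attr in self.meters:
            return self.meters[attr]
        raise AttributeError(attr)

    # dict views assumed from calls such as add_scalars(..., meters.mean, ...)
    @property
    def mean(self):
        return {k: m.avg for k, m in self.meters.items()}

    @property
    def last_item(self):
        return {k: m.deque[-1] for k, m in self.meters.items()}

    def get_all_avg(self):
        return {k: m.global_avg for k, m in self.meters.items()}

    def __str__(self):
        return self.delimiter.join(
            "{}: {:.4f} ({:.4f})".format(k, m.median, m.global_avg)
            for k, m in self.meters.items())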
Example no. 2
    def eval(self, iteration=-1, summary_writer=None):
        start = time.time()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            # self.derenderer.eval()
            val_metrics = MetricLogger(delimiter="  ")
            val_loss_logger = MetricLogger(delimiter="  ")
            for i, inputs in enumerate(self.val_loader, iteration):
                # data_time = time.time() - last_batch_time

                if torch.cuda.is_available():
                    inputs = to_cuda(inputs)

                output = self.derenderer(inputs)

                loss_dict = gather_loss_dict(output)
                val_loss_logger.update(**loss_dict)
                summary_writer.add_scalars("val_non_smooth", val_loss_logger.last_item, i)

                all_preds.append({k:v.cpu().numpy() for k,v in output["output"].items()})
                all_labels.append({k:v.cpu().numpy() for k,v in inputs["attributes"].items()})
                # all_labels = self.attributes.cat_by_key(all_labels, inputs["attributes"])
                # all_preds = self.attributes.cat_by_key(all_preds, output['output'])


                # batch_time = time.time() - last_batch_time
                # val_metrics.update(time=batch_time, data=data_time)
                if time.time() - start > self.cfg.SOLVER.VALIDATION_MAX_SECS:
                    break

            # merge the per-batch dicts into a single array per key
            all_preds, all_labels = map(lambda l: {k: np.concatenate([a[k] for a in l]) for k in l[0].keys()},
                                        [all_preds, all_labels])
            # all_preds = {k: np.concatenate([a[k] for a in all_preds]) for k in all_preds[0].keys()}
            # all_labels = {k: np.concatenate([a[k] for a in all_labels]) for k in all_labels[0].keys()}
            err_dict = self.attributes.pred_error(all_preds, all_labels)
            val_metrics.update(**err_dict)
            log.info(val_metrics.delimiter.join(["VALIDATION", "iter: {iter}", "{meters}"])
                     .format(iter=iteration, meters=str(val_metrics)))
            log.info(val_metrics.delimiter.join(["VALIDATION", "iter: {iter}", "{meters}"])
                     .format(iter=iteration, meters=str(val_loss_logger)))
            if summary_writer is not None:
                summary_writer.add_scalars("val_error", val_metrics.mean, iteration)
                summary_writer.add_scalars("val", val_loss_logger.mean, iteration)
            # self.derenderer.train()
        return err_dict
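The eval loop above also relies on two helpers that are not shown in this snippet, to_cuda and gather_loss_dict. Their real definitions are not part of the example; the sketches below are inferred purely from how they are called (in particular, the "losses" key inside the model output is an assumption):

import torch


def to_cuda(inputs):
    # recursively move tensors (possibly nested in dicts / lists / tuples) to the GPU
    if isinstance(inputs, torch.Tensor):
        return inputs.cuda(non_blocking=True)
    if isinstance(inputs, dict):
        return {k: to_cuda(v) for k, v in inputs.items()}
    if isinstance(inputs, (list, tuple)):
        return type(inputs)(to_cuda(v) for v in inputs)
    return inputs


def gather_loss_dict(output):
    # collect the individual loss terms from the model output and add their sum
    # under 'loss', which the training loop in Example no. 7 backpropagates
    loss_dict = dict(output["losses"])
    loss_dict["loss"] = sum(loss_dict.values())
    return loss_dict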
Example no. 3
def train_val(model, loaders, optimizer, scheduler, losses, metrics=None):
    n_epochs = cfg.SOLVER.NUM_EPOCHS
    end = time.time()
    best_dice = 0.0
    for epoch in range(n_epochs):
        scheduler.step()
        for phase in ['train', 'eval']:
            meters = MetricLogger(delimiter=" ")
            loader = loaders[phase]
            getattr(model, phase)()
            logger = logging.getLogger(phase)
            total = len(loader)
            for batch_id, (batch_x, batch_y) in enumerate(loader):
                batch_x = batch_x.cuda(non_blocking=True)
                batch_y = batch_y.cuda(non_blocking=True)
                with torch.set_grad_enabled(phase == 'train'):
                    output, vout, mu, logvar = model(batch_x)
                    loss_dict = losses['dice_vae'](cfg, output, batch_x, batch_y, vout, mu, logvar)
                meters.update(**loss_dict)
                if phase == 'train':
                    optimizer.zero_grad()
                    loss_dict['loss'].backward()
                    optimizer.step()
                else:
                    if metrics and (epoch + 1) % 20 == 0:
                        with torch.no_grad():
                            hausdorff = metrics['hd']
                            metric_dict = hausdorff(output, batch_y)
                            meters.update(**metric_dict)
                        save_sample(output, batch_x, batch_y, epoch, batch_id)
                logger.info(meters.delimiter.join([f"Epoch: {epoch}, Batch:{batch_id}/{total}",
                                                   f"{str(meters)}",
                                                   f"Time: {time.time() - end: .3f}"
                                                   ]))
                end = time.time()

            if phase == 'eval':
                dice = 1 - (meters.wt_loss.global_avg + meters.tc_loss.global_avg + meters.et_loss.global_avg) / 3
                state = {}
                if len(cfg.GPU.ID) > 1:
                    state['model'] = model.module.state_dict()
                else:
                    state['model'] = model.state_dict()
                state['optimizer'] = optimizer.state_dict()
                file_name = os.path.join(cfg.LOG_DIR, cfg.TASK_NAME, 'epoch' + str(epoch) + '.pt')
                torch.save(state, file_name)
                if dice > best_dice:
                    best_dice = dice
                    shutil.copyfile(file_name, os.path.join(cfg.LOG_DIR, cfg.TASK_NAME, 'best_model.pth'))

    return model
def do_evaluation(cfg, model, data_loader_val, device, arguments, summary_writer):
    # get logger
    logger = logging.getLogger(cfg.NAME)
    logger.info("Start evaluation  ...")
    if isinstance(model, DistributedDataParallel):
        model = model.module
    model.eval()
    meters_val = MetricLogger(delimiter="  ")
    # TODO: add compare module which can test sampled train set
    # bar = TqdmBar(data_loader_val, 0, get_rank(), data_loader_val.__len__(),
    #               description='Validation', use_bar=cfg.USE_BAR)
    # for iteration, record in bar.bar:
    #     record = move_to_device(record, device)
    #     loss, prediction = model(record)
    #     # reduce losses over all GPUs for logging purposes
    #     loss_reduced = {key: value.cpu().item() for key, value in loss.items()}
    #     meters_val.update(**loss_reduced)
    # bar.close()
    # logger.info(
    #     meters_val.delimiter.join(
    #         [
    #             "[Validation]: ",
    #             "iter: {iter}",
    #             "{meters}",
    #             "mem: {memory:.0f}",
    #         ]
    #     ).format(
    #         iter=arguments["iteration"],
    #         meters=str(meters_val),
    #         memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
    #     )
    # )
    record = {name: meter.global_avg for name, meter in meters_val.meters.items()}
    write_summary(summary_writer, arguments["iteration"], record=record, group='Valid_Losses')

    inference = build_inference(cfg)
    model.eval()
    inference(cfg, model, data_loader_val, device, iteration=arguments["iteration"], summary_writer=summary_writer,
              logger=logger, visualize=False)
    model.train()
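write_summary is also undefined here; a minimal sketch consistent with the call above (writing a dict of averaged validation losses under a single TensorBoard group) could be:

def write_summary(summary_writer, iteration, record, group):
    # the signature mirrors the call in do_evaluation and is an assumption,
    # not the original helper
    if summary_writer is None or not record:
        return
    summary_writer.add_scalars(group, record, iteration)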
Example no. 5
class MultiViewH36M(JointsDataset):
    actual_joints = {
        0: 'root',
        1: 'rhip',
        2: 'rkne',
        3: 'rank',
        4: 'lhip',
        5: 'lkne',
        6: 'lank',
        7: 'belly',
        8: 'neck',
        9: 'nose',
        10: 'head',
        11: 'lsho',
        12: 'lelb',
        13: 'lwri',
        14: 'rsho',
        15: 'relb',
        16: 'rwri'
    }

    def __init__(self, root, image_set, is_train, transform=None):
        super().__init__(root, image_set, is_train, transform)

        anno_file = osp.join(self.root, 'h36m', 'annot',
                             'h36m_{}.pkl'.format(image_set))
        self.db = self.load_db(anno_file)

        if cfg.DATASETS.H36M.FILTER_DAMAGE:
            print('before filter', len(self.db))
            self.db = [
                db_rec for db_rec in self.db if not self.isdamaged(db_rec)
            ]
            print('after filter', len(self.db))

        if cfg.DATASETS.H36M.MAPPING:
            assert cfg.KEYPOINT.NUM_PTS == 20
            self.u2a_mapping = super().get_mapping()
            super().do_mapping()
        else:
            assert cfg.KEYPOINT.NUM_PTS == 17

        self.grouping = self.get_group(self.db)
        self.group_size = len(self.grouping)

        if cfg.VIS.MULTIVIEWH36M:
            from utils.metric_logger import MetricLogger
            self.meters = MetricLogger()

    @staticmethod
    def index_to_action_names():
        return {
            2: 'Direction',
            3: 'Discuss',
            4: 'Eating',
            5: 'Greet',
            6: 'Phone',
            7: 'Pose',
            8: 'Purchase',
            9: 'Sitting',
            10: 'SittingDown',
            11: 'Smoke',
            12: 'Photo',
            13: 'Wait',
            14: 'WalkDog',
            15: 'Walk',
            16: 'WalkTo'
        }

    def load_db(self, dataset_file):
        with open(dataset_file, 'rb') as f:
            dataset = pickle.load(f)
            return dataset

    def get_group(self, db):
        # group db records by (subject, action, subaction, image_id) so that each
        # group holds the db indices of the same frame seen from the four cameras
        grouping = {}
        nitems = len(db)
        for i in range(nitems):
            keystr = self.get_key_str(db[i])
            camera_id = db[i]['camera_id']
            if keystr not in grouping:
                grouping[keystr] = [-1, -1, -1, -1]
            grouping[keystr][camera_id] = i

        filtered_grouping = []
        for _, v in grouping.items():
            if np.all(np.array(v) != -1):
                filtered_grouping.append(v)

        if self.is_train:
            if cfg.DATASETS.H36M.TRAIN_SAMPLE:
                filtered_grouping = filtered_grouping[::cfg.DATASETS.H36M.TRAIN_SAMPLE]
        else:
            if cfg.DATASETS.H36M.TEST_SAMPLE:
                filtered_grouping = filtered_grouping[::cfg.DATASETS.H36M.TEST_SAMPLE]

        return filtered_grouping

    def __getitem__(self, idx):
        if cfg.VIS.H36M:
            return super().__getitem__(idx)
        items = self.grouping[idx].copy()
        data = {}
        d = {}
        for cam, item in enumerate(items):
            datum = super().__getitem__(item)
            data[cam] = datum
            d[cam] = datum['KRT']
        rank = neighbor_cameras(d)

        if self.is_train:
            #TODO our training is shorter than the original code below
            if cfg.EPIPOLAR.TOPK == 3:
                # 0~3
                ref_cam, other_cam = np.random.choice(len(items),
                                                      2,
                                                      replace=False)
            elif cfg.EPIPOLAR.TOPK == 2:
                ref_cam = np.random.randint(len(items))
                other_cam = np.random.choice(rank[ref_cam][0][:2])
            elif cfg.EPIPOLAR.TOPK == 1:
                ref_cam = np.random.randint(len(items))
                # ref_cam = random.choice(len(items))
                other_cam = rank[ref_cam][0][0]
            else:
                raise NotImplementedError

            ret = data[ref_cam]
            other_item = data[other_cam]
            if cfg.EPIPOLAR.PRIOR:
                ret['camera'] = ref_cam
                ret['other_camera'] = other_cam
            for i in ['img', 'KRT', 'heatmap', 'img-path']:
                ret['other_' + i] = other_item[i]
        # print('multi h36m this view', ret['img-path'])
        # print('multi h36m other image', ret['other_img-path'])

            if cfg.VIS.MULTIVIEWH36M:
                from vision.multiview import findFundamentalMat, camera_center
                import matplotlib.pyplot as plt
                from data.transforms.image import de_transform
                P1 = ret['KRT']
                P2 = ret['other_KRT']
                C, _ = camera_center(P1)
                C_ = np.ones(4, dtype=C.dtype)
                C_[:3] = C
                e2 = P2 @ C_
                e2 /= e2[2]
                # world3d = ret['R'].T @ ret['points-3d'].T  + ret['T']
                othercam3d = other_item['R'] @ (ret['points-3d'].T -
                                                other_item['T'])
                other2d = other_item['MultiViewH36M'] @ othercam3d
                other2d /= other2d[-1]
                N = len(other_item['points-2d'])
                # print(other2d)
                # import matplotlib.pyplot as plt
                # plt.imshow(other_item['img'].cpu().numpy().transpose((1,2,0)))
                # plt.scatter(other2d[0], other2d[1])
                # plt.show()
                F = findFundamentalMat(P1, P2, engine='numpy')
                print(F)
                # points_2d =
                print(ret['points-2d'])
                # ls = F @ np.concatenate((ret['points-2d']*4, np.ones((N, 1))), 1).T
                test_points = np.concatenate((np.ones(
                    (N, 1)) * 128, np.linspace(
                        10, 250, N)[:, None], np.ones((N, 1))), 1)
                C2, _ = camera_center(P2)
                C_ = np.ones(4, dtype=C.dtype)
                C_[:3] = C2
                e1 = P1 @ C_
                e1 /= e1[2]
                l1s = np.cross(test_points, e1)
                ls = F @ test_points.T
                # res = np.concatenate((other_item['points-2d'], np.ones((N, 1))), 1)  @ F @ np.concatenate((ret['points-2d'], np.ones((N, 1))), 1).T
                # print(res)
                fig = plt.figure(1)
                ax1 = fig.add_subplot(121)
                ax2 = fig.add_subplot(122)
                ax1.imshow(
                    de_transform(ret['img']).cpu().numpy().transpose(
                        (1, 2, 0))[..., ::-1])
                ax2.imshow(
                    de_transform(other_item['img']).cpu().numpy().transpose(
                        (1, 2, 0))[..., ::-1])

                def scatterline(l):
                    # sample the line l = (a, b, c) with ax + by + c = 0 over
                    # image x-coordinates and keep only points inside the image
                    x = np.arange(0, 256)
                    y = (-l[2] - l[0] * x) / l[1]
                    mask = (y < 256) & (y > 0)
                    return x[mask], y[mask]

                for a, b in zip(ret['points-2d'][:, 0], ret['points-2d'][:, 1]):
                    ax1.scatter(a, b, color='red')
                for a, b in zip(other_item['points-2d'][:, 0],
                                other_item['points-2d'][:, 1]):
                    ax2.scatter(a, b, color='green')
                for a, b in zip(other2d[0], other2d[1]):
                    ax2.scatter(a, b, color='red')
                for idx, (l, l1) in enumerate(zip(ls.T, l1s)):
                    # ax1.imshow(ret['original_image'][..., ::-1])
                    x1, y1 = scatterline(l1)
                    ax1.scatter(x1, y1, s=1)
                    # ax1.scatter(ret['points-2d'][idx, 0]*4, ret['points-2d'][idx, 1]*4, color='red')
                    # ax2.imshow(other_item['original_image'])
                    # ax2.scatter(other_item['points-2d'][idx, 0]*4, other_item['points-2d'][idx, 1]*4, color='yellow')
                    x, y = scatterline(l)
                    ax2.scatter(x, y, s=1)
                    # ax2.scatter(e2[0], e2[1], color='green')
                plt.show()

            return {k: totensor(v) for k, v in ret.items()}
        else:
            ret = {'camera': []}
            for k in datum.keys():
                ret[k] = []
            for k in ['img', 'KRT', 'heatmap', 'camera', 'img-path']:
                ret['other_' + k] = []
            for ref_cam, datum in data.items():
                ret['camera'].append(ref_cam)
                other_cam = rank[ref_cam][0][0]
                ret['other_camera'].append(other_cam)
                for k, v in datum.items():
                    ret[k].append(v)
                for k in ['img', 'KRT', 'heatmap', 'img-path']:
                    ret['other_' + k].append(data[other_cam][k])
            if cfg.KEYPOINT.NUM_CAM:
                for k in ret:
                    ret[k] = ret[k][:cfg.KEYPOINT.NUM_CAM]
            for k in ret:
                if not k in ['img-path', 'other_img-path']:
                    ret[k] = np.stack(ret[k])
            if cfg.DATASETS.H36M.REAL3D:
                real3d = self.computereal3d(ret['points-2d'], ret['K'],
                                            ret['RT'])
                ret['points-3d'][:] = real3d
            if cfg.VIS.MULTIVIEWH36M:
                return {
                    k: totensor(v) if isinstance(v, torch.Tensor) else v
                    for k, v in ret.items()
                }
                # NOTE: unreachable (the return above exits first); left over from a debugging pass
                self.computereal3d(ret['points-2d'], ret['K'], ret['RT'],
                                   ret['KRT'], ret['points-3d'][0])
                return ret
            return {
                k: totensor(v) if isinstance(v, torch.Tensor) else v
                for k, v in ret.items()
            }

    def __len__(self):
        if cfg.VIS.H36M:
            return super().__len__()
        return self.group_size

    def get_key_str(self, datum):
        return 's_{:02}_act_{:02}_subact_{:02}_imgid_{:06}'.format(
            datum['subject'], datum['action'], datum['subaction'],
            datum['image_id'])

    def evaluate(self, pred, *args, **kwargs):
        pred = pred.copy()

        headsize = self.image_size[0] / 10.0
        threshold = 0.5

        u2a = self.u2a_mapping
        a2u = {v: k for k, v in u2a.items() if v != '*'}
        a = list(a2u.keys())
        u = list(a2u.values())
        indexes = list(range(len(a)))
        indexes.sort(key=a.__getitem__)
        sa = list(map(a.__getitem__, indexes))
        su = np.array(list(map(u.__getitem__, indexes)))

        gt = []
        for items in self.grouping:
            for item in items:
                gt.append(self.db[item]['joints_2d'][su, :2])
        gt = np.array(gt)
        pred = pred[:, su, :2]

        distance = np.sqrt(np.sum((gt - pred)**2, axis=2))
        detected = (distance <= headsize * threshold)

        joint_detection_rate = np.sum(detected, axis=0) / float(gt.shape[0])

        name_values = collections.OrderedDict()
        joint_names = self.actual_joints
        for i in range(len(a2u)):
            name_values[joint_names[sa[i]]] = joint_detection_rate[i]
        return name_values, np.mean(joint_detection_rate)

    def computereal3d(self, pts, Ks, RTs, KRTs=None, gt3ds=None):
        from vision.triangulation import triangulate_pymvg
        if cfg.DATASETS.H36M.MAPPING:
            actualjoints = np.array(
                [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 12, 14, 15, 16, 17, 18, 19])
            pts = pts[:, actualjoints]
        confs = np.ones((pts.shape[0], pts.shape[1]))
        real3ds = triangulate_pymvg(pts, Ks, RTs, confs)
        if not cfg.VIS.MULTIVIEWH36M:
            return real3ds

        gt3ds = gt3ds.cpu().numpy()
        KRTs = KRTs.cpu().numpy()
        pts = pts.cpu().numpy()

        # print('3d delta')
        # print(gt3ds-real3ds)
        # print('3d error')
        # print(np.linalg.norm(gt3ds-real3ds, axis=1))

        def reprojerr(real3ds):
            real2ds = []
            real3ds = np.concatenate((real3ds, np.ones((len(real3ds), 1))), 1)
            for KRT in KRTs:
                real2d = KRT @ real3ds.T
                real2d /= real2d[-1]
                real2ds.append(real2d[:2].T)
            # views x Njoints x 2
            real2ds = np.stack(real2ds)
            delta = real2ds - pts
            return np.linalg.norm(delta, axis=2).sum()

        err0 = reprojerr(gt3ds)
        err1 = reprojerr(real3ds)
        self.meters.update(gt3d=err0, real3d=err1)
        print(self.meters)
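The visualization branch of __getitem__ above uses camera_center and findFundamentalMat from vision.multiview, which are not included in this snippet. A sketch of the standard constructions they appear to implement (camera centre as the right null vector of the projection matrix, and F = [e2]_x P2 P1^+ between two views) follows; the names and signatures come from the calls above, the bodies are assumptions:

import numpy as np


def camera_center(P):
    # the camera centre C is the right null vector of the 3x4 projection matrix P;
    # the second return value is not used in the snippet above
    _, _, vt = np.linalg.svd(P)
    C = vt[-1]
    return C[:3] / C[3], None


def findFundamentalMat(P1, P2, engine='numpy'):
    # F maps points in view 1 to epipolar lines in view 2: F = [e2]_x P2 P1^+
    # (engine is accepted only to match the call signature above)
    C, _ = camera_center(P1)
    e2 = P2 @ np.append(C, 1.0)                    # epipole of camera 1 in view 2
    e2_cross = np.array([[0.0, -e2[2], e2[1]],
                         [e2[2], 0.0, -e2[0]],
                         [-e2[1], e2[0], 0.0]])
    return e2_cross @ P2 @ np.linalg.pinv(P1)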
Example no. 6
        return None, None
    gt_class_scores = np.ones(num_gt_boxes)
    gt_predicate_scores = np.ones(num_gt_relations)
    gt_triplets, gt_triplet_boxes, _ = _triplet(gt_pred_labels, gt_relations,
                                                gt_classes, gt_boxes,
                                                gt_predicate_scores,
                                                gt_class_scores)
    return gt_triplets, gt_triplet_boxes


if __name__ == '__main__':
    info = json.load(
        open(os.path.join(cfg.DATASET.PATH, "VG-SGG-dicts.json"), 'r'))
    itola = info['idx_to_label']
    itopred = info['idx_to_predicate']
    meters = MetricLogger(delimiter="  ")
    data_loader = build_data_loader(cfg)
    end = time.time()
    logger = setup_logger("scene_graph_generation", "logs", get_rank())
    output_config_path = os.path.join("logs", 'config.yml')
    logger.info("Saving config into: {}".format(output_config_path))

    logger = logging.getLogger("scene_graph_generation")
    logger.info("Start training")
    max_iter = len(data_loader)
    result_dic: {str: int} = defaultdict(int)

    all_images = 0
    with open('browse_data.txt', 'w') as f:
        for i, data in enumerate(data_loader):
            data_time = time.time() - end
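This snippet is truncated and depends on a _triplet helper typical of Visual Genome scene-graph evaluation code. A sketch of the usual implementation, with the argument order taken from the call above, turns per-image classes, relations and boxes into one (subject, predicate, object) row per relation:

import numpy as np


def _triplet(predicates, relations, classes, boxes,
             predicate_scores, class_scores):
    # relations: (num_relations, 2) array of subject / object indices
    # into classes and boxes
    assert len(predicates) == len(relations)
    sub_idx, obj_idx = relations[:, 0], relations[:, 1]
    triplets = np.column_stack((classes[sub_idx], predicates, classes[obj_idx]))
    triplet_boxes = np.column_stack((boxes[sub_idx], boxes[obj_idx]))
    triplet_scores = np.column_stack(
        (class_scores[sub_idx], predicate_scores, class_scores[obj_idx]))
    return triplets, triplet_boxes, triplet_scores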
Example no. 7
    def train(self, log_flag=True):
        train_metrics = MetricLogger(delimiter="  ")
        summary_writer = SummaryWriter(log_dir=os.path.join(self.output_dir, "summary"))

        self.derenderer.train()

        # Initialize timing
        timers = create_new_timer()

        done = False
        while not done:
            for iteration, inputs in enumerate(self.train_loader, self.start_iteration):
                iter_time = time.time()
                data_time = iter_time - timers.batch

                if torch.cuda.is_available():
                    inputs = to_cuda(inputs)

                output = self.derenderer(inputs)

                loss_dict = gather_loss_dict(output)
                loss = loss_dict['loss']
                # loss = sum([loss_dict[term] for term in ['x', 'y', 'z']])

                if torch.isnan(loss).any():
                    raise Nan_Exception()

                train_metrics.update(**loss_dict)
                summary_writer.add_scalars("train_non_smooth", train_metrics.last_item, iteration)


                batch_time = iter_time - timers.batch
                timers.batch = iter_time
                train_metrics.update(time=batch_time, data=data_time)
                eta_seconds = timers.start + self.cfg.SOLVER.MAX_TIME_SECS - iter_time
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                if (iter_time - timers.log > self.cfg.SOLVER.PRINT_METRICS_TIME and log_flag):
                    timers.log = iter_time
                    log.info(train_metrics.delimiter.join(["eta: {eta}", "iter: {iter}", "{meters}", "lr: {lr:.6f}",
                                                           "max mem: {memory:.0f}"]).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(train_metrics),
                        lr=self.optimizer.param_groups[0]["lr"],
                        memory=proc_id.memory_info().rss / 1e9)
                    )
                    summary_writer.add_scalars("train", train_metrics.mean, iteration)

                if iter_time - timers.checkpoint > self.cfg.SOLVER.CHECKPOINT_SECS: #iteration % checkpoint_period == 0:
                    timers.checkpoint = iter_time
                    self.checkpointer.save("model_{:07d}".format(iteration))

                if iter_time - timers.tensorboard > self.cfg.SOLVER.TENSORBOARD_SECS or self.cfg.DEBUG:
                    timers.tensorboard = iter_time
                    summary_writer.add_scalars("train", train_metrics.mean, iteration)


                if iter_time - timers.start > self.cfg.SOLVER.MAX_TIME_SECS:
                    log.info("finished training loop in {}".format(iter_time-timers.start))
                    done = True
                    break

                if iter_time - timers.validation > self.cfg.SOLVER.VALIDATION_SECS:
                    err_dict = self.eval(iteration, summary_writer)
                    timers.validation = time.time()

                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            log.info("*******  epoch done  after {}  *********".format(time.time() - timers.epoch))
            timers.epoch = time.time()
            self.start_iteration = iteration

        err_dict = self.eval(iteration, summary_writer)

        self.checkpointer.save("model_{:07d}".format(iteration))
        summary_writer.close()
        return err_dict
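All of the time-based scheduling in this loop (logging, checkpointing, TensorBoard writes, validation, stopping) is driven by a timers object returned by create_new_timer(), which is not shown. A minimal sketch consistent with the fields accessed above would be:

import time
from types import SimpleNamespace


def create_new_timer():
    # every timer starts at "now"; the loop compares time.time() against these
    # fields to decide when to log, checkpoint, write summaries, validate, or stop
    now = time.time()
    return SimpleNamespace(start=now, batch=now, log=now, checkpoint=now,
                           tensorboard=now, validation=now, epoch=now)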
Example no. 8
def test(cfg, model=None):
    torch.cuda.empty_cache()  # TODO check if it helps
    cpu_device = torch.device("cpu")
    if cfg.VIS.FLOPS:
        # device = cpu_device
        device = torch.device("cuda:0")
    else:
        device = torch.device(cfg.DEVICE)
    if model is None:
        # load model from outputs
        model = Modelbuilder(cfg)
        model.to(device)
        checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR)
        _ = checkpointer.load(cfg.WEIGHTS)
    data_loaders = make_data_loader(cfg, is_train=False)
    if cfg.VIS.FLOPS:
        model.eval()
        from thop import profile
        for idx, batchdata in enumerate(data_loaders[0]):
            with torch.no_grad():
                flops, params = profile(
                    model,
                    inputs=({
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batchdata.items()
                    }, False))
                print('flops', flops, 'params', params)
                exit()
    if cfg.TEST.RECOMPUTE_BN:
        tmp_data_loader = make_data_loader(cfg,
                                           is_train=True,
                                           dataset_list=cfg.DATASETS.TEST)
        model.train()
        for idx, batchdata in enumerate(tqdm(tmp_data_loader)):
            with torch.no_grad():
                model(
                    {
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batchdata.items()
                    },
                    is_train=True)
        #cnt = 0
        #while cnt < 1000:
        #    for idx, batchdata in enumerate(tqdm(tmp_data_loader)):
        #        with torch.no_grad():
        #            model({k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batchdata.items()}, is_train=True)
        #        cnt += 1
        checkpointer.save("model_bn")
        model.eval()
    elif cfg.TEST.TRAIN_BN:
        model.train()
    else:
        model.eval()
    dataset_names = cfg.DATASETS.TEST
    meters = MetricLogger()

    #if cfg.TEST.PCK and cfg.DOTEST and 'h36m' in cfg.OUTPUT_DIR:
    #    all_preds = np.zeros((len(data_loaders), cfg.KEYPOINT.NUM_PTS, 3), dtype=np.float32)
    cpu = lambda x: x.to(cpu_device).numpy() if isinstance(x, torch.Tensor) else x

    logger = setup_logger("tester", cfg.OUTPUT_DIR)
    for data_loader, dataset_name in zip(data_loaders, dataset_names):
        print('Loading ', dataset_name)
        dataset = data_loader.dataset

        logger.info("Start evaluation on {} dataset({} images).".format(
            dataset_name, len(dataset)))
        total_timer = Timer()
        total_timer.tic()

        predictions = []
        #if 'h36m' in cfg.OUTPUT_DIR:
        #    err_joints = 0
        #else:
        err_joints = np.zeros((cfg.TEST.IMS_PER_BATCH, int(cfg.TEST.MAX_TH)))
        total_joints = 0

        for idx, batchdata in enumerate(tqdm(data_loader)):
            if cfg.VIS.VIDEO and not 'h36m' in cfg.OUTPUT_DIR:
                for k, v in batchdata.items():
                    try:
                        #good 1 2 3 4 5 6 7 8 12 16 30
                        # 4 17.4 vs 16.5
                        # 30 41.83200 vs 40.17562
                        #bad 0 22
                        #0 43.78544 vs 45.24059
                        #22 43.01385 vs 43.88636
                        vis_idx = 16
                        batchdata[k] = v[:, vis_idx, None]
                    except:
                        pass
            if cfg.VIS.VIDEO_GT:
                for k, v in batchdata.items():
                    try:
                        vis_idx = 30
                        batchdata[k] = v[:, vis_idx:vis_idx + 2]
                    except:
                        pass
                joints = cpu(batchdata['points-2d'].squeeze())[0]
                orig_img = de_transform(
                    cpu(batchdata['img'].squeeze()[None, ...])[0][0])
                # fig = plt.figure()
                # ax = fig.add_subplot(111)
                ax = display_image_in_actual_size(orig_img.shape[1],
                                                  orig_img.shape[2])
                if 'h36m' in cfg.OUTPUT_DIR:
                    draw_2d_pose(joints, ax)
                    orig_img = orig_img[::-1]
                else:
                    visibility = cpu(batchdata['visibility'].squeeze())[0]
                    plot_two_hand_2d(joints, ax, visibility)
                    # plot_two_hand_2d(joints, ax)
                ax.imshow(orig_img.transpose((1, 2, 0)))
                ax.axis('off')
                output_folder = os.path.join("outs", "video_gt", dataset_name)
                mkdir(output_folder)
                plt.savefig(os.path.join(output_folder, "%08d" % idx),
                            bbox_inches="tight",
                            pad_inches=0)
                plt.cla()
                plt.clf()
                plt.close()
                continue
            #print('batchdatapoints-3d', batchdata['points-3d'])
            batch_size = cfg.TEST.IMS_PER_BATCH
            with torch.no_grad():
                loss_dict, metric_dict, output = model(
                    {
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batchdata.items()
                    },
                    is_train=False)
            meters.update(**prefix_dict(loss_dict, dataset_name))
            meters.update(**prefix_dict(metric_dict, dataset_name))
            # update err_joints
            if cfg.VIS.VIDEO:
                joints = cpu(output['batch_locs'].squeeze())
                if joints.shape[0] == 1:
                    joints = joints[0]
                try:
                    orig_img = de_transform(
                        cpu(batchdata['img'].squeeze()[None, ...])[0][0])
                except:
                    orig_img = de_transform(
                        cpu(batchdata['img'].squeeze()[None, ...])
                        [0])  # fig = plt.figure()
                # ax = fig.add_subplot(111)
                ax = display_image_in_actual_size(orig_img.shape[1],
                                                  orig_img.shape[2])
                if 'h36m' in cfg.OUTPUT_DIR:
                    draw_2d_pose(joints, ax)
                    orig_img = orig_img[::-1]
                else:
                    visibility = cpu(batchdata['visibility'].squeeze())
                    if visibility.shape[0] == 1:
                        visibility = visibility[0]
                    plot_two_hand_2d(joints, ax, visibility)
                ax.imshow(orig_img.transpose((1, 2, 0)))
                ax.axis('off')
                output_folder = os.path.join(cfg.OUTPUT_DIR, "video",
                                             dataset_name)
                mkdir(output_folder)
                plt.savefig(os.path.join(output_folder, "%08d" % idx),
                            bbox_inches="tight",
                            pad_inches=0)
                plt.cla()
                plt.clf()
                plt.close()
                # plt.show()

            if cfg.TEST.PCK and cfg.DOTEST:
                #if 'h36m' in cfg.OUTPUT_DIR:
                #    err_joints += metric_dict['accuracy'] * output['total_joints']
                #    total_joints += output['total_joints']
                #    # all_preds
                #else:
                for i in range(batch_size):
                    err_joints = np.add(err_joints, output['err_joints'])
                    total_joints += sum(output['total_joints'])

            if idx % cfg.VIS.SAVE_PRED_FREQ == 0 and (
                    cfg.VIS.SAVE_PRED_LIMIT == -1
                    or idx < cfg.VIS.SAVE_PRED_LIMIT * cfg.VIS.SAVE_PRED_FREQ):
                # print(meters)
                for i in range(batch_size):
                    predictions.append((
                        {
                            k: (cpu(v[i]) if not isinstance(v, int) else v)
                            for k, v in batchdata.items()
                        },
                        {
                            k: (cpu(v[i]) if not isinstance(v, int) else v)
                            for k, v in output.items()
                        },
                    ))
            if cfg.VIS.SAVE_PRED_LIMIT != -1 and idx > cfg.VIS.SAVE_PRED_LIMIT * cfg.VIS.SAVE_PRED_FREQ:
                break

            # if not cfg.DOTRAIN and cfg.SAVE_PRED:
            #     if cfg.VIS.SAVE_PRED_LIMIT != -1 and idx < cfg.VIS.SAVE_PRED_LIMIT:
            #         for i in range(batch_size):
            #             predictions.append(
            #                     (
            #                         {k: (cpu(v[i]) if not isinstance(v, int) else v) for k, v in batchdata.items()},
            #                         {k: (cpu(v[i]) if not isinstance(v, int) else v) for k, v in output.items()},
            #                     )
            #             )
            #     if idx == cfg.VIS.SAVE_PRED_LIMIT:
            #         break
        #if cfg.TEST.PCK and cfg.DOTEST and 'h36m' in cfg.OUTPUT_DIR:
        #    logger.info('accuracy0.5: {}'.format(err_joints/total_joints))
        # dataset.evaluate(all_preds)
        # name_value, perf_indicator = dataset.evaluate(all_preds)
        # names = name_value.keys()
        # values = name_value.values()
        # num_values = len(name_value)
        # logger.info(' '.join(['| {}'.format(name) for name in names]) + ' |')
        # logger.info('|---' * (num_values) + '|')
        # logger.info(' '.join(['| {:.3f}'.format(value) for value in values]) + ' |')

        total_time = total_timer.toc()
        total_time_str = get_time_str(total_time)
        logger.info("Total run time: {} ".format(total_time_str))

        if cfg.OUTPUT_DIR:  #and cfg.VIS.SAVE_PRED:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            torch.save(predictions,
                       os.path.join(output_folder, cfg.VIS.SAVE_PRED_NAME))
            if cfg.DOTEST and cfg.TEST.PCK:
                print(err_joints.shape)
                torch.save(err_joints * 1.0 / total_joints,
                           os.path.join(output_folder, "pck.pth"))

    logger.info("{}".format(str(meters)))

    model.train()
    return meters.get_all_avg()
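prefix_dict, used when updating meters above, is not defined in the snippet; judging from the call it simply namespaces the metric keys per dataset, for example:

def prefix_dict(d, prefix):
    # prepend the dataset name so metrics from different test sets do not
    # collide inside the shared MetricLogger; the exact separator is an assumption
    return {'{}_{}'.format(prefix, k): v for k, v in d.items()}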
Example no. 9
    def train(self, resume=False, from_save_folder=False):
        if resume:
            self.resume_training_load(from_save_folder)
        self.logger.info("Start training")
        meters = MetricLogger(delimiter="  ")
        max_iter = len(self.train_loader)

        self.model.train()

        end = time.time()

        running_loss = 0.
        running_loss_classifier = 0.
        running_loss_box_reg = 0.
        running_loss_mask = 0.
        running_loss_objectness = 0.
        running_loss_rpn_box_reg = 0.
        running_loss_mimicking_cls = 0.
        running_loss_mimicking_cos_sim = 0.

        val_loss = None
        bbox_mmap = None
        segm_mmap = None

        start_step = self.step
        for _, (images, targets,
                _) in tqdm(enumerate(self.train_loader, start_step)):
            data_time = time.time() - end
            self.step += 1
            self.schedule_lr()

            self.optimizer.zero_grad()

            images = images.to(self.device)
            targets = [target.to(self.device) for target in targets]

            loss_dict = self.model(images, targets)
            loss_dict = self.weight_loss(loss_dict)

            losses = sum(loss for loss in loss_dict.values())

            losses.backward()
            self.optimizer.step()

            torch.cuda.empty_cache()

            meters.update(loss=losses, **loss_dict)
            running_loss += losses.item()
            running_loss_classifier += loss_dict['loss_classifier']
            running_loss_box_reg += loss_dict['loss_box_reg']
            running_loss_mask += loss_dict['loss_mask']
            running_loss_objectness += loss_dict['loss_objectness']
            running_loss_rpn_box_reg += loss_dict['loss_rpn_box_reg']
            running_loss_mimicking_cls += loss_dict['loss_mimicking_cls']
            running_loss_mimicking_cos_sim += loss_dict[
                'loss_mimicking_cos_sim']

            if self.step != 0:
                if self.step % self.board_loss_every == 0:
                    self.board_scalars(
                        'train', running_loss / self.board_loss_every,
                        running_loss_classifier / self.board_loss_every,
                        running_loss_box_reg / self.board_loss_every,
                        running_loss_mask / self.board_loss_every,
                        running_loss_objectness / self.board_loss_every,
                        running_loss_rpn_box_reg / self.board_loss_every,
                        running_loss_mimicking_cls / self.board_loss_every,
                        running_loss_mimicking_cos_sim / self.board_loss_every)
                    running_loss = 0.
                    running_loss_classifier = 0.
                    running_loss_box_reg = 0.
                    running_loss_mask = 0.
                    running_loss_objectness = 0.
                    running_loss_rpn_box_reg = 0.
                    running_loss_mimicking_cls = 0.
                    running_loss_mimicking_cos_sim = 0.

                if self.step % self.evaluate_every == 0:
                    self.model.train()
                    (val_loss, val_loss_classifier, val_loss_box_reg,
                     val_loss_mask, val_loss_objectness, val_loss_rpn_box_reg,
                     val_loss_mimicking_cls,
                     val_loss_mimicking_cos_sim) = self.evaluate(num=self.cfg.SOLVER.EVAL_NUM)
                    self.board_scalars('val', val_loss,
                                       val_loss_classifier.item(),
                                       val_loss_box_reg.item(),
                                       val_loss_mask.item(),
                                       val_loss_objectness.item(),
                                       val_loss_rpn_box_reg.item(),
                                       val_loss_mimicking_cls.item(),
                                       val_loss_mimicking_cos_sim.item())

                if self.step % self.board_pred_image_every == 0:
                    self.model.eval()
                    for i in range(20):
                        img_path = (Path(self.val_loader.dataset.root) /
                                    self.val_loader.dataset.get_img_info(i)['file_name'])
                        cv_img = cv2.imread(str(img_path))
                        predicted_img = self.predictor.run_on_opencv_image(
                            cv_img)
                        self.writer.add_image(
                            'pred_image_{}'.format(i),
                            F.to_tensor(Image.fromarray(predicted_img)),
                            global_step=self.step)
                    self.model.train()

                if self.step % self.inference_every == 0:
                    self.model.eval()
                    try:
                        with torch.no_grad():
                            cocoEval = inference(self.model,
                                                 self.val_loader,
                                                 'coco2014',
                                                 iou_types=['bbox', 'segm'])[0]
                            bbox_map05 = cocoEval.results['bbox']['AP50']
                            bbox_mmap = cocoEval.results['bbox']['AP']
                            segm_map05 = cocoEval.results['segm']['AP50']
                            segm_mmap = cocoEval.results['segm']['AP']
                    except:
                        print('eval on coco failed')
                        bbox_map05 = -1
                        bbox_mmap = -1
                        segm_map05 = -1
                        segm_mmap = -1
                    self.model.train()
                    self.writer.add_scalar('bbox_map05', bbox_map05, self.step)
                    self.writer.add_scalar('bbox_mmap', bbox_mmap, self.step)
                    self.writer.add_scalar('segm_map05', segm_map05, self.step)
                    self.writer.add_scalar('segm_mmap', segm_mmap, self.step)

                if self.step % self.save_every == 0:
                    try:
                        self.save_state(val_loss, bbox_mmap, segm_mmap)
                    except:
                        print('save state failed')
                        self.step += 1
                        continue
                    if self.step % (10 * self.save_every) == 0:
                        try:
                            self.save_state(val_loss,
                                            bbox_mmap,
                                            segm_mmap,
                                            to_save_folder=True)
                        except:
                            print('save state failed')
                            self.step += 1
                            continue

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            eta_seconds = meters.time.global_avg * (max_iter - self.step)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if self.step % 20 == 0 or self.step == max_iter:
                self.logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        iter=self.step,
                        meters=str(meters),
                        lr=self.optimizer.param_groups[0]["lr"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 /
                        1024.0,
                    ))
            if self.step >= max_iter:
                self.save_state(val_loss,
                                bbox_mmap,
                                segm_mmap,
                                to_save_folder=True)
                return
def do_train(
    cfg,
    model,
    train_dataloader,
    val_dataloader,
    optimizer,
    lr_scheduler,
    checkpointer,
    device,
    checkpoint_period,
    test_period,
    log_period,
    arguments,
):
    logger = logging.getLogger("EfficientDet.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(train_dataloader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()

    for iteration, (images, targets, _) in enumerate(train_dataloader,
                                                     start_iter):

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        targets = targets.to(device)

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        # lr_scheduler.step(losses_reduced)
        lr_scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % log_period == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if val_dataloader is not None and (
            (test_period > 0 and iteration % test_period == 0)
                or iteration == max_iter):
            meters_val = MetricLogger(delimiter="  ")
            synchronize()
            map_05_09 = do_infer(  # The result can be used for additional logging, e.g. for TensorBoard
                model,
                val_dataloader,
                dataset_name="[Validation]",
                device=cfg.device,
                output_folder=None,
            )
            logger.info("Validation MAP 0.5:0.9 ===> {}".format(map_05_09))
            synchronize()
            model.train()
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
    def train(self):
        # save new  configuration
        with open(os.path.join(self.output_dir, "cfg.yaml"), 'w') as f:
            x = self.cfg.dump(indent=4)
            f.write(x)

        log.info(f'New training run with configuration:\n{self.cfg}\n\n')
        train_metrics = MetricLogger(delimiter="  ")
        summary_writer = SummaryWriter(log_dir=os.path.join(self.output_dir, "summary"))

        self.model.train()
        timers = create_new_timer()
        # Initialize timing

        done = False
        while not done:
            for iteration, inputs in enumerate(self.train_loader, self.start_iteration):
                iter_time = time.time()
                data_time = iter_time - timers.batch
                inputs = to_cuda(inputs)

                out = self.model(inputs)
                loss_dict = out['loss_dict']
                loss = loss_dict["loss"]

                if torch.isnan(loss).any():
                    raise Nan_Exception()

                train_metrics.update(**loss_dict)
                summary_writer.add_scalars("train_non_smooth", train_metrics.last_item, iteration)

                batch_time = iter_time - timers.batch
                timers.batch = iter_time
                train_metrics.update(time=batch_time, data=data_time)
                eta_seconds = timers.start + self.cfg.SOLVER.MAX_TIME_SECS - iter_time
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                if (iter_time - timers.log > self.cfg.SOLVER.PRINT_METRICS_TIME):
                    timers.log = iter_time
                    log.info(train_metrics.delimiter.join(["eta: {eta}", "iter: {iter}", "{meters}", "lr: {lr:.6f}",
                                                           "max mem: {memory:.0f}"]).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(train_metrics),
                        lr=self.optimizer.param_groups[0]["lr"],
                        memory=proc_id.memory_info().rss / 1e9)
                    )
                    summary_writer.add_scalars("train", train_metrics.mean, iteration)

                if iter_time - timers.checkpoint > self.cfg.SOLVER.CHECKPOINT_SECS:  # iteration % checkpoint_period == 0:
                    timers.checkpoint = iter_time
                    self.checkpointer.save("model_{:07d}".format(iteration))

                if iter_time - timers.tensorboard > self.cfg.SOLVER.TENSORBOARD_SECS or self.cfg.DEBUG:
                    timers.tensorboard = iter_time
                    summary_writer.add_scalars("train", train_metrics.mean, iteration)

                if iter_time - timers.start > self.cfg.SOLVER.MAX_TIME_SECS:
                    log.info("finished training loop in {}".format(iter_time - timers.start))
                    done = True
                    break

                if iter_time - timers.validation > self.cfg.SOLVER.VALIDATION_SECS:
                    err_dict = self.eval(iteration, summary_writer)
                    timers.validation = time.time()

                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            log.info("*******  epoch done  after {}  *********".format(time.time() - timers.epoch))
            timers.epoch = time.time()
            self.start_iteration = iteration
    def eval(self, iteration, summary_writer):
        start = time.time()
        all_preds = []
        all_labels = []

        evals = []
        with torch.no_grad():
            self.model.eval()
            # self.derenderer.eval()
            val_metrics = MetricLogger(delimiter="  ")
            val_loss_logger = MetricLogger(delimiter="  ")
            for i, inputs in enumerate(self.val_loader, iteration):
                # data_time = time.time() - last_batch_time

                if torch.cuda.is_available():
                    inputs = to_cuda(inputs)

                output = self.model(inputs, match=True)
                loss_dict = output["loss_dict"]
                is_possible = inputs['is_possible']
                magic_penalty = output['magic_penalty']

                for i in range(len(magic_penalty)):
                    frame = {}
                    frame['is_possible'] = bool(is_possible[i])
                    frame['inverse_likelihood'] = float(magic_penalty[i])
                    evals.append(frame)

                # target = inputs['targets']
                # output = output['output']
                # is_possible = inputs['is_possible']

                # loc_x_gt = target['location_x']
                # loc_y_gt = target['location_y']
                # loc_z_gt = target['location_z']

                # output_x = output['location_x'].squeeze()
                # output_y = output['location_y'].squeeze()
                # output_z = output['location_z'].squeeze()
                # existance = target['existance'][:, 1:]

                # loss_trans_x = torch.pow(output_x - loc_x_gt[:, 1:], 2) * existance
                # loss_trans_y = torch.pow(output_y - loc_y_gt[:, 1:], 2) * existance
                # loss_trans_z = torch.pow(output_z - loc_z_gt[:, 1:], 2) * existance

                # loss_trans_x = loss_trans_x.mean(dim=2).mean(dim=1)
                # loss_trans_y = loss_trans_y.mean(dim=2).mean(dim=1)
                # loss_trans_z = loss_trans_z.mean(dim=2).mean(dim=1)

                # loss = loss_trans_z + loss_trans_y + loss_trans_x
                # energy_pos = loss[is_possible]
                # energy_neg = loss[~is_possible]

                # energy_pos = energy_pos.detach().cpu().numpy()
                # energy_neg = energy_neg.detach().cpu().numpy()

                # for i in range(energy_pos.shape[0]):
                #     frame = {}
                #     frame['is_possible'] = True
                #     frame['likelihood'] = float(energy_pos[i])
                #     evals.append(frame)

                # for i in range(energy_neg.shape[0]):
                #     frame = {}
                #     frame['is_possible'] = False
                #     frame['likelihood'] = float(energy_neg[i])
                #     evals.append(frame)

                # print("possible: ", energy_pos.mean())
                # print("not possible: ", energy_neg.mean())




                val_loss_logger.update(**loss_dict)
                # summary_writer.add_scalars("val_non_smooth", val_loss_logger.last_item, i)

                # all_preds.append({k: v.cpu().numpy() for k, v in output["output"].items()})
                # all_labels.append({k: v.cpu().numpy() for k, v in inputs["attributes"].items()})
                # all_labels = self.attributes.cat_by_key(all_labels, inputs["attributes"])
                # all_preds = self.attributes.cat_by_key(all_preds, output['output'])

                # batch_time = time.time() - last_batch_time
                # val_metrics.update(time=batch_time, data=data_time)
                # if time.time() - start > self.cfg.SOLVER.VALIDATION_MAX_SECS:
                #     raise Val_Too_Long

            # all_preds, all_labels = map(lambda l: {k: np.concatenate([a[k] for a in l]) for k in l[0].keys()},
            #                             [all_preds, all_labels])
            # all_preds = {k: np.concatenate([a[k] for a in all_preds]) for k in all_preds[0].keys()}
            # all_labels = {k: np.concatenate([a[k] for a in all_labels]) for k in all_labels[0].keys()}
            # err_dict = self.attributes.pred_error(all_preds, all_labels)
            # val_metrics.update(**err_dict)
            # log.info(val_metrics.delimiter.join(["VALIDATION", "iter: {iter}", "{meters}"])
            #          .format(iter=iteration, meters=str(val_metrics)))
            log.info(val_metrics.delimiter.join(["VALIDATION", "iter: {iter}", "{meters}"])
                     .format(iter=iteration, meters=str(val_loss_logger)))
            if summary_writer is not None:
                # summary_writer.add_scalars("val_error", val_metrics.mean, iteration)
                summary_writer.add_scalars("val", val_loss_logger.mean, iteration)
            # self.derenderer.train()
        with open("output.json", "w") as f:
            json.dump(evals, f)
        self.model.train()
        return None
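do_train above averages the loss dict across GPUs with reduce_loss_dict before logging. That helper is not shown here; the maskrcnn-benchmark-style implementation it appears to follow looks roughly like this (a sketch, assuming torch.distributed is initialized whenever the world size is greater than one):

import torch
import torch.distributed as dist


def get_world_size():
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()


def reduce_loss_dict(loss_dict):
    # average the loss values over all processes so rank 0 logs the true mean;
    # the returned dict is only meaningful on rank 0
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        names = sorted(loss_dict.keys())
        values = torch.stack([loss_dict[k] for k in names], dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0:
            values /= world_size
        return {k: v for k, v in zip(names, values)}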
Example no. 13
def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger
):
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(train_loader)

    start_iter = arguments["iteration"]
    best_iteration = -1
    best_recall = 0

    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets) in enumerate(train_loader, start_iter):

        if iteration % cfg.VALIDATION.VERBOSE == 0 or iteration == max_iter:
            model.eval()
            logger.info('Validation')
            labels = val_loader.dataset.label_list
            labels = np.array([int(k) for k in labels])
            feats = feat_extractor(model, val_loader, logger=logger)
            #print(feats.shape)

            ret_metric = RetMetric(feats=feats, labels=labels)
            recall_curr = ret_metric.recall_k(1)
            remap_var = ret_metric.re_map(5)

            # print(remap_var)

            if recall_curr > best_recall:
                best_recall = recall_curr
                best_iteration = iteration
                logger.info(f'Best iteration {iteration}: recall@1: {best_recall:.3f}')
                checkpointer.save("best_model")
            else:
                logger.info(f'Recall@1 at iteration {iteration:06d}: {recall_curr:.3f}')

        model.train()
        model.apply(set_bn_eval)

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = torch.stack([target.to(device) for target in targets])
       # print("pppppppp")
        feats = model(images)
      #  print("After feat = model(images), the size of feat:")
      #  print(feats.shape)
        loss = criterion(feats, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time, loss=loss.item())

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.1f} GB",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0,
                )
            )

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:06d}".format(iteration))

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )

    logger.info(f"Best iteration: {best_iteration :06d} | best recall {best_recall} ")
Esempio n. 14
0
    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
    else:
        assert False, args.consistency_type

    writer = SummaryWriter(snapshot_path + '/log')

    iter_num = args.global_step
    lr_ = base_lr
    model.train()

    #train
    for epoch in range(args.start_epoch, args.epochs):
        meters_loss = MetricLogger(delimiter="  ")
        meters_loss_classification = MetricLogger(delimiter="  ")
        meters_loss_consistency = MetricLogger(delimiter="  ")
        meters_loss_consistency_relation = MetricLogger(delimiter="  ")
        time1 = time.time()
        iter_max = len(train_dataloader)
        for i, (_, _, (image_batch, ema_image_batch),
                label_batch) in enumerate(train_dataloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            image_batch, ema_image_batch, label_batch = image_batch.cuda(
            ), ema_image_batch.cuda(), label_batch.cuda()
            # unlabeled_image_batch = ema_image_batch[labeled_bs:]

            # noise1 = torch.clamp(torch.randn_like(image_batch) * 0.1, -0.1, 0.1)
            # noise2 = torch.clamp(torch.randn_like(ema_image_batch) * 0.1, -0.1, 0.1)
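
Esempio n. 14 chooses between losses.softmax_mse_loss and losses.softmax_kl_loss for the mean-teacher consistency term; neither definition is included in the snippet. A common formulation of the MSE variant, modeled on the usual mean-teacher reference code and therefore only an assumption here, compares the softened predictions of the student and the EMA teacher:

import torch.nn.functional as F


def softmax_mse_loss(input_logits, target_logits):
    """MSE between softmax outputs of student and teacher logits (sketch)."""
    assert input_logits.size() == target_logits.size()
    input_softmax = F.softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    num_classes = input_logits.size(1)
    # sum over the batch, normalized by the number of classes
    return F.mse_loss(input_softmax, target_softmax, reduction="sum") / num_classes
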
Esempio n. 15
0
def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    device,
    args,
    cfg,
):
    meters = MetricLogger(delimiter="  ")
    start_iter = args["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    # This loop runs over whole epochs, which differs from the originally planned iteration count.
    iter_count = 0
    max_acc = max_acc_avg = 0
    data_loader_val = make_data_loader(cfg, is_train=False)
    for epoch in range(cfg.SOLVER.EPOCH):
        for iteration, (images, levels) in enumerate(data_loader, start_iter):
            # images = to_image_list(images)
            data_time = time.time() - end
            iter_count += 1
            scheduler.step()
            images = images.tensors
            images = images.to(device)
            levels = torch.tensor(levels).cuda()
            loss_dict, prediction = model(images, levels)
            # compute training accuracy for this batch
            prediction = torch.argmax(prediction, dim=1)
            train_acc = (prediction == levels).float().mean().item()

            losses = sum(loss for loss in loss_dict.values())
            meters.update(loss=losses, train_acc=train_acc, **loss_dict)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if iter_count % 10 == 0:
                print("epoch: {epoch}, iter: {iter}, {meters}, lr: {lr:.6f}".
                      format(epoch=epoch,
                             iter=iter_count,
                             meters=str(meters),
                             lr=optimizer.param_groups[0]["lr"]))
        if epoch % 2 == 1:
            # Evaluate every other epoch during training and save the model whenever
            # it improves; run the test first, then switch back to train mode.
            # TODO: also test on the 'other' split
            print('epoch ', epoch, ': testing the model')
            acc_avg, accuracy = inference(model, data_loader_val, device)
            if accuracy > max_acc:
                max_acc = accuracy
                torch.save(
                    model.state_dict(),
                    os.path.join(cfg.OUTPUT_DIR,
                                 'model' + str(epoch) + '.pth'))
                print("saving model")
            elif acc_avg > max_acc_avg:
                max_acc_avg = acc_avg
                torch.save(
                    model.state_dict(),
                    os.path.join(cfg.OUTPUT_DIR,
                                 'model' + str(epoch) + '.pth'))
                print("saving model")
            model.train()

    # report total training time
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    print("Total training time: {} ({:.4f} s / epoch)".format(
        total_time_str, total_training_time / cfg.SOLVER.EPOCH))
    torch.save(model.state_dict(),
               os.path.join(cfg.OUTPUT_DIR, 'final_model.pth'))
    return model
Esempio n. 16
0
def do_train(cfg, model, data_loader, optimizer, scheduler,
             criterion, checkpointer, device, arguments,
             tblogger, data_loader_val, distributed):
    logger = logging.getLogger('eve.' + __name__)
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments['iteration']
    model.train()
    start_training_time = time.time()
    end = time.time()
    logger.info("Start training")
    logger.info("Arguments: {}".format(arguments))

    for iteration, batch in enumerate(data_loader, start_iter):
        model.train()
        data_time = time.time() - end
        iteration = iteration + 1
        arguments['iteration'] = iteration

        # FIXME: for eve, modify dataloader
        locs, feats, targets, _ = batch
        inputs = ME.SparseTensor(feats, coords=locs).to(device)
        targets = targets.to(device, non_blocking=True).long()
        out = model(inputs, y=targets)

        if len(out) == 2:  # minkunet_eve
            outputs, match = out
        else:
            outputs = out
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if len(out) == 2:  # FIXME
            loss_dict = dict(loss=loss, match_acc=match[0], match_time=match[1])
        else:
            loss_dict = dict(loss=loss)
        loss_dict_reduced = reduce_dict(loss_dict)
        meters.update(**loss_dict_reduced)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data_time=data_time)
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if tblogger is not None:
            for name, meter in meters.meters.items():
                if 'time' in name:
                    tblogger.add_scalar(
                        'other/' + name, meter.median, iteration)
                else:
                    tblogger.add_scalar(
                        'train/' + name, meter.median, iteration)
            tblogger.add_scalar(
                'other/lr', optimizer.param_groups[0]['lr'], iteration)

        if iteration % cfg.SOLVER.LOG_PERIOD == 0 \
                or iteration == max_iter \
                or iteration == 0:
            logger.info(
                meters.delimiter.join(
                    [
                        "train eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]['lr'],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )

        scheduler.step()

        if iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
            checkpointer.save('model_{:06d}'.format(iteration), **arguments)

        if iteration % 100 == 0:
            checkpointer.save('model_last', **arguments)

        if iteration == max_iter:
            checkpointer.save('model_final', **arguments)

        if iteration % cfg.SOLVER.EVAL_PERIOD == 0 \
                or iteration == max_iter:
            metrics = val_in_train(
                model,
                criterion,
                cfg.DATASETS.VAL,
                data_loader_val,
                tblogger,
                iteration,
                checkpointer,
                distributed)

            if metrics is not None:
                if arguments['best_iou'] < metrics['iou']:
                    arguments['best_iou'] = metrics['iou']
                    logger.info('best_iou: {}'.format(arguments['best_iou']))
                    checkpointer.save('model_best', **arguments)
                else:
                    logger.info('best_iou: {}'.format(arguments['best_iou']))

            if tblogger is not None:
                tblogger.add_scalar(
                    'val/best_iou', arguments['best_iou'], iteration)

            model.train()

            end = time.time()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )
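
Esempio n. 16 passes its loss dictionary through reduce_dict before logging so that, under distributed training, every rank reports the same averaged values. That helper is not part of the snippet; a typical implementation, assuming torch.distributed as the backend, is sketched below:

import torch
import torch.distributed as dist


def reduce_dict(input_dict, average=True):
    """All-reduce a dict of 0-dim tensors across ranks; a no-op on a single process."""
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() < 2:
        return input_dict
    with torch.no_grad():
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= dist.get_world_size()
        return dict(zip(names, values))
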
Esempio n. 17
0
def train(cfg, args):
    train_set = DatasetCatalog.get(cfg.DATASETS.TRAIN, args)
    val_set = DatasetCatalog.get(cfg.DATASETS.VAL, args)
    train_loader = DataLoader(train_set,
                              cfg.SOLVER.IMS_PER_BATCH,
                              num_workers=cfg.DATALOADER.NUM_WORKERS,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            cfg.SOLVER.IMS_PER_BATCH,
                            num_workers=cfg.DATALOADER.NUM_WORKERS,
                            shuffle=True)

    gpu_ids = list(range(torch.cuda.device_count()))
    model = build_model(cfg)
    model.to("cuda")
    model = torch.nn.parallel.DataParallel(
        model, gpu_ids) if not args.debug else model

    logger = logging.getLogger("train_logger")
    logger.info("Start training")
    train_metrics = MetricLogger(delimiter="  ")
    max_iter = cfg.SOLVER.MAX_ITER
    output_dir = cfg.OUTPUT_DIR

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
    checkpointer = Checkpointer(model, optimizer, scheduler, output_dir,
                                logger)
    start_iteration = checkpointer.load() if not args.debug else 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    validation_period = cfg.SOLVER.VALIDATION_PERIOD
    summary_writer = SummaryWriter(log_dir=os.path.join(output_dir, "summary"))
    visualizer = train_set.visualizer(cfg.VISUALIZATION)(summary_writer)

    model.train()
    start_training_time = time.time()
    last_batch_time = time.time()

    for iteration, inputs in enumerate(cycle(train_loader), start_iteration):
        data_time = time.time() - last_batch_time
        iteration = iteration + 1
        scheduler.step()

        inputs = to_cuda(inputs)
        outputs = model(inputs)

        loss_dict = gather_loss_dict(outputs)
        loss = loss_dict["loss"]
        train_metrics.update(**loss_dict)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time = time.time() - last_batch_time
        last_batch_time = time.time()
        train_metrics.update(time=batch_time, data=data_time)

        eta_seconds = train_metrics.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                train_metrics.delimiter.join([
                    "eta: {eta}", "iter: {iter}", "{meters}", "lr: {lr:.6f}",
                    "max mem: {memory:.0f}"
                ]).format(eta=eta_string,
                          iter=iteration,
                          meters=str(train_metrics),
                          lr=optimizer.param_groups[0]["lr"],
                          memory=torch.cuda.max_memory_allocated() / 1024.0 /
                          1024.0))
            summary_writer.add_scalars("train", train_metrics.mean, iteration)

        if iteration % 100 == 0:
            visualizer.visualize(inputs, outputs, iteration)

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration))

        if iteration % validation_period == 0:
            with torch.no_grad():
                val_metrics = MetricLogger(delimiter="  ")
                for i, inputs in enumerate(val_loader):
                    data_time = time.time() - last_batch_time

                    inputs = to_cuda(inputs)
                    outputs = model(inputs)

                    loss_dict = gather_loss_dict(outputs)
                    val_metrics.update(**loss_dict)

                    batch_time = time.time() - last_batch_time
                    last_batch_time = time.time()
                    val_metrics.update(time=batch_time, data=data_time)

                    if i % 20 == 0 or i == cfg.SOLVER.VALIDATION_LIMIT:
                        logger.info(
                            val_metrics.delimiter.join([
                                "VALIDATION", "eta: {eta}", "iter: {iter}",
                                "{meters}"
                            ]).format(eta=eta_string,
                                      iter=iteration,
                                      meters=str(val_metrics)))

                    if i == cfg.SOLVER.VALIDATION_LIMIT:
                        summary_writer.add_scalars("val", val_metrics.mean,
                                                   iteration)
                        break
        if iteration == max_iter:
            break

    checkpointer.save("model_{:07d}".format(max_iter))
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
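
Esempio n. 17 wraps the training loader in cycle(...) so that the loop is bounded by cfg.SOLVER.MAX_ITER instead of by epochs. The helper itself is not part of the snippet; itertools.cycle would replay one cached epoch in the same order, so the intended behaviour is more likely a re-iterating generator along these lines (an assumption):

def cycle(iterable):
    """Yield items from `iterable` forever, restarting it when exhausted.

    Re-creating the iterator on each pass means a DataLoader with shuffle=True
    gets a fresh shuffle every epoch, unlike itertools.cycle."""
    while True:
        for item in iterable:
            yield item
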
Esempio n. 18
0
def do_train(cfg, model, data_loader_train, data_loader_val, optimizer,
             scheduler, checkpointer, device, arguments, summary_writer):
    # get logger
    logger = logging.getLogger(cfg.NAME)
    logger.info("Start training ...")
    logger.info("Size of training dataset: %s" %
                (data_loader_train.dataset.__len__()))
    logger.info("Size of validation dataset: %s" %
                (data_loader_val.dataset.__len__()))

    model.train()

    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader_train)
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()
    bar = TqdmBar(data_loader_train,
                  start_iter,
                  get_rank(),
                  data_loader_train.__len__(),
                  description="Training",
                  use_bar=cfg.USE_BAR)

    for iteration, record in bar.bar:
        data_time = time.time() - end
        iteration += 1
        arguments["iteration"] = iteration
        record = move_to_device(record, device)

        loss, _ = model(record)
        optimizer.zero_grad()
        loss["total_loss"].backward()
        optimizer.step()
        scheduler.step()

        # reduce losses over all GPUs for logging purposes
        loss_reduced = {key: value.cpu().item() for key, value in loss.items()}
        meters.update(**loss_reduced)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        lr = optimizer.param_groups[0]["lr"]
        bar.set_postfix({"lr": lr, "total_loss": loss_reduced["total_loss"]})

        if iteration % cfg.SOLVER.LOGGER_PERIOD == 0 or iteration == max_iter:
            bar.clear(nolock=True)
            logger.info(
                meters.delimiter.join([
                    "iter: {iter:06d}",
                    "lr: {lr:.6f}",
                    "{meters}",
                    "eta: {eta}",
                    "mem: {memory:.0f}",
                ]).format(
                    iter=iteration,
                    lr=lr,
                    meters=str(meters),
                    eta=eta_string,
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

            if summary_writer:
                write_summary(summary_writer,
                              iteration,
                              record=loss,
                              group='Losses')
                write_summary(summary_writer,
                              iteration,
                              record={'lr': lr},
                              group='LR')

        if iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
            bar.clear(nolock=True)
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
            if data_loader_val is not None:
                do_evaluation(cfg, model, data_loader_val, device, arguments,
                              summary_writer)

    checkpointer.save("model_final", **arguments)

    bar.close()
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / max_iter))
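
The loop above relies on a move_to_device helper to push a record (possibly nested dicts or lists of tensors) onto the GPU. Its implementation is not included in the snippet; a minimal recursive version might look like this sketch:

import torch


def move_to_device(obj, device):
    """Recursively move tensors inside nested dicts/lists/tuples to `device`."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device, non_blocking=True)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    return obj  # leave non-tensor leaves (ints, strings, ...) untouched
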
Esempio n. 19
0
def train(args):
    try:
        model = nets[args.net](args.margin, args.omega, args.use_hardtriplet)
        model.to(args.device)
    except Exception as e:
        logger.error("Initialize {} error: {}".format(args.net, e))
        return
    logger.info("Training {}.".format(args.net))

    optimizer = make_optimizer(args, model)
    scheduler = make_scheduler(args, optimizer)

    if args.device != torch.device("cpu"):
        amp_opt_level = 'O1' if args.use_amp else 'O0'
        model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    arguments = {}
    arguments.update(vars(args))
    arguments["itr"] = 0
    checkpointer = Checkpointer(model, 
                                optimizer=optimizer, 
                                scheduler=scheduler,
                                save_dir=args.out_dir, 
                                save_to_disk=True)
    ## load model from pretrained_weights or training break_point.
    extra_checkpoint_data = checkpointer.load(args.pretrained_weights)
    arguments.update(extra_checkpoint_data)
    
    batch_size = args.batch_size
    fashion = FashionDataset(item_num=args.iteration_num*batch_size)
    dataloader = DataLoader(dataset=fashion, shuffle=True, num_workers=8, batch_size=batch_size)

    model.train()
    meters = MetricLogger(delimiter=", ")
    max_itr = len(dataloader)
    start_itr = arguments["itr"] + 1
    itr_start_time = time.time()
    training_start_time = time.time()
    for itr, batch_data in enumerate(dataloader, start_itr):
        batch_data = (bd.to(args.device) for bd in batch_data)
        loss_dict = model.loss(*batch_data)
        optimizer.zero_grad()
        if args.device != torch.device("cpu"):
            with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_losses:
                scaled_losses.backward()
        else:
            loss_dict["loss"].backward()
        optimizer.step()
        scheduler.step()

        arguments["itr"] = itr
        meters.update(**loss_dict)
        itr_time = time.time() - itr_start_time
        itr_start_time = time.time()
        meters.update(itr_time=itr_time)
        if itr % 50 == 0:
            eta_seconds = meters.itr_time.global_avg * (max_itr - itr)
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                meters.delimiter.join(
                    [
                        "itr: {itr}/{max_itr}",
                        "lr: {lr:.7f}",
                        "{meters}",
                        "eta: {eta}\n",
                    ]
                ).format(
                    itr=itr,
                    lr=optimizer.param_groups[0]["lr"],
                    max_itr=max_itr,
                    meters=str(meters),
                    eta=eta,
                )
            )

        ## save model
        if itr % args.checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(itr), **arguments)
        if itr == max_itr:
            checkpointer.save("model_final", **arguments)
            break

    training_time = time.time() - training_start_time
    training_time = str(datetime.timedelta(seconds=int(training_time)))
    logger.info("total training time: {}".format(training_time))
Esempio n. 20
0
def train(cfg):
    device = torch.device(cfg.DEVICE)
    arguments = {}
    arguments["epoch"] = 0
    if not cfg.DATALOADER.BENCHMARK:
        model = Modelbuilder(cfg)
        print(model)
        model.to(device)
        model.float()
        optimizer, scheduler = make_optimizer(cfg, model)
        checkpointer = Checkpointer(model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    save_dir=cfg.OUTPUT_DIR)
        extra_checkpoint_data = checkpointer.load(
            cfg.WEIGHTS,
            prefix=cfg.WEIGHTS_PREFIX,
            prefix_replace=cfg.WEIGHTS_PREFIX_REPLACE,
            loadoptimizer=cfg.WEIGHTS_LOAD_OPT)
        arguments.update(extra_checkpoint_data)
        model.train()

    logger = setup_logger("trainer", cfg.FOLDER_NAME)
    if cfg.TENSORBOARD.USE:
        writer = SummaryWriter(cfg.FOLDER_NAME)
    else:
        writer = None
    meters = MetricLogger(writer=writer)
    start_training_time = time.time()
    end = time.time()
    start_epoch = arguments["epoch"]
    max_epoch = cfg.SOLVER.MAX_EPOCHS

    if start_epoch == max_epoch:
        logger.info("Final model exists! No need to train!")
        test(cfg, model)
        return

    data_loader = make_data_loader(
        cfg,
        is_train=True,
    )
    size_epoch = len(data_loader)
    max_iter = size_epoch * max_epoch
    logger.info("Start training {} batches/epoch".format(size_epoch))

    for epoch in range(start_epoch, max_epoch):
        arguments["epoch"] = epoch
        for iteration, batchdata in enumerate(data_loader):
            cur_iter = size_epoch * epoch + iteration
            data_time = time.time() - end

            batchdata = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batchdata.items()
            }

            if not cfg.DATALOADER.BENCHMARK:
                loss_dict, metric_dict = model(batchdata)
                # print(loss_dict, metric_dict)
                optimizer.zero_grad()
                loss_dict['loss'].backward()
                optimizer.step()

            batch_time = time.time() - end
            end = time.time()

            meters.update(time=batch_time, data=data_time, iteration=cur_iter)

            if cfg.DATALOADER.BENCHMARK:
                logger.info(
                    meters.delimiter.join([
                        "iter: {iter}",
                        "{meters}",
                    ]).format(
                        iter=iteration,
                        meters=str(meters),
                    ))
                continue

            eta_seconds = meters.time.global_avg * (max_iter - cur_iter)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % cfg.LOG_FREQ == 0:
                meters.update(iteration=cur_iter, **loss_dict)
                meters.update(iteration=cur_iter, **metric_dict)
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "epoch: {epoch}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        # "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        epoch=epoch,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        # memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    ))
        # Note: since PyTorch 1.1, lr_scheduler.step() must be called after optimizer.step(),
        # otherwise the first value of the learning-rate schedule is skipped. The per-epoch
        # scheduler therefore steps only after the inner loop (and its optimizer.step() calls).
        # See https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        scheduler.step()

        if (epoch + 1) % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
            arguments["epoch"] += 1
            checkpointer.save("model_{:03d}".format(epoch), **arguments)
        if epoch == max_epoch - 1:
            arguments['epoch'] = max_epoch
            checkpointer.save("model_final", **arguments)

            total_training_time = time.time() - start_training_time
            total_time_str = str(
                datetime.timedelta(seconds=total_training_time))
            logger.info("Total training time: {} ({:.4f} s / epoch)".format(
                total_time_str,
                total_training_time / (max_epoch - start_epoch)))
        if epoch == max_epoch - 1 or ((epoch + 1) % cfg.EVAL_FREQ == 0):
            results = test(cfg, model)
            meters.update(is_train=False, iteration=cur_iter, **results)
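
The comment in Esempio n. 20 about lr_scheduler.step() refers to the PyTorch >= 1.1 rule that the scheduler must step after optimizer.step(). A tiny self-contained illustration of that per-epoch ordering (dummy model and loss, purely illustrative):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

for epoch in range(3):
    for _ in range(5):                      # stand-in for iterating a DataLoader
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()       # dummy loss just to drive the update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                    # update weights first ...
    scheduler.step()                        # ... then advance the per-epoch LR schedule
    print(epoch, optimizer.param_groups[0]["lr"])
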
Esempio n. 21
0
def train_one_epoch(model, optimizer: torch.optim, data_loader: DataLoader,
                    criterion: torch.nn.modules.loss, device: torch.device,
                    epoch: int, print_freq: int):
    """
    Method to train Plant model one time
    Args:
        model (plant_model): model to train
        optimizer (torch.optimizer): optimizer used
        criterion (loss): loss
        data_loader (torch.utils.data.DataLoader): data loader to test model on
        device (torch.device): device to use, either device("cpu") or device("cuda")
        epoch (int = None): state epoch of the current training if there is one
        print_freq (int = None): print frequency of log writer
        writer (SummaryWriter = None): set a writer if you want to write log file in tensorboard
    """
    # Set model to train mode
    model.train()
    # Define metric logger parameters
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr',
                            SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    metric = ComputeMetrics(n_classes=4)

    running_loss, epoch_loss = 0, 0

    mixup_data = MixupData()
    mixup_criterion = MixupCriterion()

    # Iterate over dataloader to train model
    for images, labels in metric_logger.log_every(data_loader, print_freq,
                                                  epoch, header):
        images = images.to(device)
        labels = labels.to(device)

        mixup = False
        # Compute loss
        if mixup:
            inputs, targets_a, targets_b, lam = mixup_data(images, labels)
            inputs, targets_a, targets_b = map(torch.tensor,
                                               (inputs, targets_a, targets_b))

            outputs = model(inputs)

            loss = mixup_criterion(criterion, outputs, targets_a, targets_b,
                                   lam)
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
        running_loss += loss.item()

        # Process backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute ROC AUC
        metric.update(outputs, labels)

        # metrics = compute_roc_auc(outputs, labels)
        metric_logger.update(loss=loss)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    epoch_metric = metric.get_auc_roc()
    epoch_metric["loss"] = running_loss / len(data_loader.dataset)

    print(epoch_metric)
    return epoch_metric
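
Esempio n. 21 toggles mixup through MixupData and MixupCriterion objects whose implementations are not shown. The standard mixup recipe (Zhang et al., 2018) that such helpers usually follow, sketched here as plain functions and therefore only an assumption about their exact behaviour:

import numpy as np
import torch


def mixup_data(x, y, alpha=1.0):
    """Return mixed inputs, the two target sets, and the mixing coefficient lam."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Loss for mixup: the same convex combination applied to the two targets."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
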