def evaluate_batch(self, batch: TorchData,
                   model: nn.Module) -> Dict[str, Any]:
        """Score one validation batch and log its predictions to TensorBoard.

        The metric is the average, over all target boxes in the batch, of the
        best IoU any NMS-filtered predicted box achieves against that target.
        IoU is the area of the intersection over the area of the union.

        Args:
            batch: A pair ``(images, targets)``; every target is indexed with
                ``"boxes"``.
            model: Called as ``model(list_of_images, list_of_targets)``; its
                result is indexed per image with ``"boxes"`` and ``"scores"``.

        Returns:
            ``{"val_avg_iou": value}``.  The value is ``0.0`` when the batch
            contains no target boxes at all.
        """
        images, targets = batch
        # The model receives a deep copy of the targets (presumably so that it
        # cannot modify the originals that are scored below).
        output = model(list(images), copy.deepcopy(list(targets)))
        sum_iou = 0
        num_boxes = 0
        # Instantiate the Tensorboard writer and set the log_dir to
        # /tmp/tensorboard where Determined looks for events
        writer = SummaryWriter(log_dir="/tmp/tensorboard")

        try:
            for idx, target in enumerate(targets):
                # Filter out overlapping bounding box predictions based on
                # non-maximum suppression (NMS)
                predicted_boxes = output[idx]["boxes"]
                prediction_scores = output[idx]["scores"]
                keep_indices = torchvision.ops.nms(predicted_boxes,
                                                   prediction_scores, 0.1)
                predicted_boxes = torch.index_select(predicted_boxes, 0,
                                                     keep_indices)

                # Tally IoU with respect to the ground truth target boxes.
                # boxes_iou has one row per target and one column per
                # prediction; an empty matrix (no targets or no surviving
                # predictions) contributes an IoU of zero instead of raising.
                target_boxes = target["boxes"]
                boxes_iou = torchvision.ops.box_iou(target_boxes,
                                                    predicted_boxes)
                if boxes_iou.numel() > 0:
                    sum_iou += boxes_iou.max(dim=1).values.sum()
                num_boxes += len(target_boxes)

                # Draw the NMS-filtered predictions on the image.
                # NOTE(review): no score threshold is applied to the drawn
                # boxes -- confirm whether low-confidence boxes should be
                # hidden in the visualization.
                writer.add_image_with_boxes("step_" + str(self.current_step),
                                            images[idx], predicted_boxes)
        finally:
            # Always release the event file, even if scoring fails midway.
            writer.close()

        if num_boxes == 0:
            return {"val_avg_iou": 0.0}
        return {"val_avg_iou": sum_iou / num_boxes}
예제 #2
0
    class PytorchTBWriter(object):
        """Small convenience wrapper around a TensorBoard ``SummaryWriter``.

        Every logging helper flushes the writer after it has written, so the
        events are handed to the writer's flush immediately.
        """

        def __init__(self, *inputs):
            """Create the underlying writer.

            Args:
                *inputs: ``(args,)`` or ``(args, log_dir)``.  When exactly two
                    values are given, the second is used as the log directory.
                    Otherwise a fresh ``experiment_<id>`` directory is created
                    under ``args.log_dir/args.dataset/args.checkname``, where
                    ``<id>`` is one more than the highest existing id (or 0).
            """
            args = inputs[0]
            # The explicit directory is optional, so the attribute must always
            # exist before it is tested below.
            self.log_dir = inputs[1] if len(inputs) == 2 else None

            if self.log_dir is None:
                directory = os.path.join(args.log_dir, args.dataset, args.checkname)

                runs = glob.glob(os.path.join(directory, 'experiment_*'))
                exist_run_ids = [int(r.split('_')[-1]) for r in runs]
                run_id = max(exist_run_ids) + 1 if exist_run_ids else 0

                self.log_dir = os.path.join(directory, 'experiment_{}'.format(run_id))
                os.makedirs(self.log_dir, exist_ok=True)

            self.writer = SummaryWriter(self.log_dir)

        def add_images_with_bboxes(self, tag, images, bbox, step, labels=None, dataformats='CHW'):
            '''bbox:  N x 4 (xmin, ymin, xmax, ymax), absolute values'''
            # NOTE(review): `labels` is accepted but not forwarded to the
            # writer -- confirm whether box labels should be drawn.
            self.writer.add_image_with_boxes(tag, images, bbox, step, dataformats=dataformats)
            self.writer.flush()

        def list_of_scalars_summary(self, tag_value_pairs, step):
            """Log every ``(tag, value)`` pair at ``step``, then flush once."""
            for tag, value in tag_value_pairs:
                self.writer.add_scalar(tag, value, step)
            self.writer.flush()

        def add_image(self, tag, image, step):
            """Log a single image under ``tag`` at ``step`` and flush."""
            self.writer.add_image(tag, image, step)
            self.writer.flush()

        def add_images(self, tag, images, step):
            """Log a batch of images under ``tag`` at ``step`` and flush."""
            self.writer.add_images(tag, images, step)
            self.writer.flush()
예제 #3
0
class SummaryWriter:
    """Switchable facade over ``TensorboardSummaryWriter``.

    Each ``add_*`` call is forwarded to the wrapped writer only while
    ``active`` is true.  For the methods that take a ``global_step``, a value
    of ``None`` is replaced by the step stored via :meth:`set_global_step`.
    Creating an instance also publishes its bound ``add_*`` methods (with the
    ``add_`` prefix removed) and ``set_*`` methods as attributes of this
    module.
    """

    def __init__(self, logdir, flush_secs=120):
        """Open the wrapped writer on ``logdir`` and register module functions."""
        self.writer = TensorboardSummaryWriter(
            log_dir=logdir,
            purge_step=None,
            max_queue=10,
            flush_secs=flush_secs,
            filename_suffix='')

        self.global_step = None
        self.active = True

        # Expose this instance's add_* / set_* methods at module level:
        # add_scalar becomes <module>.scalar, set_active stays
        # <module>.set_active.
        module = sys.modules[__name__]
        for attr in dir(SummaryWriter):
            if attr.startswith('add_'):
                setattr(module, attr[len('add_'):], getattr(self, attr))
            elif attr.startswith('set_'):
                setattr(module, attr, getattr(self, attr))

    def set_global_step(self, value):
        """Store the step used whenever a call passes ``global_step=None``."""
        self.global_step = value

    def set_active(self, value):
        """Enable (truthy) or disable (falsy) forwarding of ``add_*`` calls."""
        self.active = value

    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
        """Forward to the wrapped writer's ``add_audio`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_audio(
            tag, snd_tensor, global_step=global_step, sample_rate=sample_rate, walltime=walltime)

    def add_custom_scalars(self, layout):
        """Forward to the wrapped writer's ``add_custom_scalars`` when active."""
        if not self.active:
            return
        self.writer.add_custom_scalars(layout)

    def add_custom_scalars_marginchart(self, tags, category='default', title='untitled'):
        """Forward to ``add_custom_scalars_marginchart`` when active."""
        if not self.active:
            return
        self.writer.add_custom_scalars_marginchart(tags, category=category, title=title)

    def add_custom_scalars_multilinechart(self, tags, category='default', title='untitled'):
        """Forward to ``add_custom_scalars_multilinechart`` when active."""
        if not self.active:
            return
        self.writer.add_custom_scalars_multilinechart(tags, category=category, title=title)

    def add_embedding(self, mat, metadata=None, label_img=None, global_step=None,
                      tag='default', metadata_header=None):
        """Forward to the wrapped writer's ``add_embedding`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_embedding(
            mat, metadata=metadata, label_img=label_img, global_step=global_step,
            tag=tag, metadata_header=metadata_header)

    def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
        """Forward to the wrapped writer's ``add_figure`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_figure(
            tag, figure, global_step=global_step, close=close, walltime=walltime)

    def add_graph(self, model, input_to_model=None, verbose=False):
        """Forward to the wrapped writer's ``add_graph`` when active."""
        if not self.active:
            return
        self.writer.add_graph(model, input_to_model=input_to_model, verbose=verbose)

    def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):
        """Forward to the wrapped writer's ``add_histogram`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_histogram(
            tag, values, global_step=global_step, bins=bins,
            walltime=walltime, max_bins=max_bins)

    def add_histogram_raw(self, tag, min, max, num, sum, sum_squares,
                          bucket_limits, bucket_counts, global_step=None,
                          walltime=None):
        """Forward to the wrapped writer's ``add_histogram_raw`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_histogram_raw(
            tag, min=min, max=max, num=num, sum=sum, sum_squares=sum_squares,
            bucket_limits=bucket_limits, bucket_counts=bucket_counts,
            global_step=global_step, walltime=walltime)

    def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
        """Forward to the wrapped writer's ``add_image`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_image(
            tag, img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats)

    def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
                             walltime=None, rescale=1, dataformats='CHW'):
        """Forward to the wrapped writer's ``add_image_with_boxes`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_image_with_boxes(
            tag, img_tensor, box_tensor,
            global_step=global_step, walltime=walltime,
            rescale=rescale, dataformats=dataformats)

    def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
        """Forward to the wrapped writer's ``add_images`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_images(
            tag, img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats)

    def add_mesh(self, tag, vertices, colors=None, faces=None, config_dict=None, global_step=None, walltime=None):
        """Forward to the wrapped writer's ``add_mesh`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_mesh(
            tag, vertices, colors=colors, faces=faces, config_dict=config_dict,
            global_step=global_step, walltime=walltime)

    def add_onnx_graph(self, graph):
        """Forward to the wrapped writer's ``add_onnx_graph`` when active."""
        if not self.active:
            return
        self.writer.add_onnx_graph(graph)

    def add_pr_curve(self, tag, labels, predictions, global_step=None,
                     num_thresholds=127, weights=None, walltime=None):
        """Forward to the wrapped writer's ``add_pr_curve`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_pr_curve(
            tag, labels, predictions, global_step=global_step,
            num_thresholds=num_thresholds, weights=weights, walltime=walltime)

    def add_pr_curve_raw(self, tag, true_positive_counts,
                         false_positive_counts,
                         true_negative_counts,
                         false_negative_counts,
                         precision,
                         recall,
                         global_step=None,
                         num_thresholds=127,
                         weights=None,
                         walltime=None):
        """Forward to the wrapped writer's ``add_pr_curve_raw`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_pr_curve_raw(
            tag, true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
            global_step=global_step,
            num_thresholds=num_thresholds,
            weights=weights,
            walltime=walltime)

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        """Forward to the wrapped writer's ``add_scalar`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_scalar(
            tag, scalar_value, global_step=global_step, walltime=walltime)

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
        """Forward to the wrapped writer's ``add_scalars`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_scalars(
            main_tag, tag_scalar_dict, global_step=global_step, walltime=walltime)

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        """Forward to the wrapped writer's ``add_text`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_text(
            tag, text_string, global_step=global_step, walltime=walltime)

    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
        """Forward to the wrapped writer's ``add_video`` when active."""
        if not self.active:
            return
        if global_step is None:
            global_step = self.global_step
        self.writer.add_video(
            tag, vid_tensor, global_step=global_step, fps=fps, walltime=walltime)

    def close(self):
        """Close the wrapped writer (regardless of the ``active`` flag)."""
        self.writer.close()

    def __enter__(self):
        """Enter the wrapped writer's context and return whatever it returns.

        NOTE(review): the value bound by ``with ... as w`` is therefore the
        result of the wrapped writer's ``__enter__``, not this facade --
        confirm that callers expect this.
        """
        inner = self.writer.__enter__()
        return inner

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Delegate context exit to the wrapped writer."""
        return self.writer.__exit__(exc_type, exc_val, exc_tb)
예제 #4
0
def train():
    """Train the detection model and log progress to TensorBoard.

    Reads its configuration from names defined outside this function
    (``batch_size``, ``num_epochs``, ``learning_rate``, ``train_set_ratio``,
    ``data_path``, ``cls2id``, ``num_workers``, ``num_classes``, ``momentum``,
    ``weight_decay``, ``model_save_path``).  After every epoch it evaluates on
    the held-out split, logs scalars and one annotated sample image, and saves
    the model whenever AP50 or accuracy reaches a new maximum.  The best
    accuracy with its confusion matrix, precision and recall is printed at the
    end.
    """
    # train on the GPU or on the CPU, if a GPU is not available
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    time_stamp = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    stamp = "relation" + '_b' + str(batch_size) + '_e' + str(
        num_epochs) + '_lr' + str(learning_rate) + '_tr' + str(train_set_ratio)
    writer = SummaryWriter(comment=stamp)

    # use our dataset and defined transformations
    dataset = Labelme_Dataset(data_path,
                              cls2id=cls2id,
                              transforms=get_transform(train=True))
    dataset_test = Labelme_Dataset(data_path,
                                   cls2id=cls2id,
                                   transforms=get_transform(train=False))

    # split the dataset in train and test set; both subsets index the same
    # permutation so they do not overlap
    num_train_set = round(train_set_ratio * len(dataset))
    indices = torch.randperm(len(dataset)).tolist()
    dataset = torch.utils.data.Subset(dataset, indices[:num_train_set])
    dataset_test = torch.utils.data.Subset(dataset_test,
                                           indices[num_train_set:])

    # define training and validation data loaders
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   num_workers=num_workers,
                                                   collate_fn=utils.collate_fn)

    # get the model using our helper function
    model = get_model(num_classes)

    # move model to the right device
    model.to(device)

    # construct an optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params,
                                lr=learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)
    # and a learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              mode='max',
                                                              patience=10)

    # Both checkpoints below are written to the same file, so it always holds
    # the most recently saved "best" model.
    checkpoint_path = os.path.join(model_save_path,
                                   stamp + '_' + time_stamp + '.pth')
    class_names = ["background"] + list(cls2id.keys())

    max_AP = 0
    max_acc = 0
    cm_when_max_acc = 0
    # Empty sequences keep the final summary printable even when accuracy
    # never rises above zero.
    ps_when_max_acc = []
    rc_when_max_acc = []
    try:
        for epoch in range(num_epochs):
            # train for one epoch, printing every 10 iterations
            logger = train_one_epoch(model,
                                     optimizer,
                                     data_loader,
                                     device,
                                     epoch,
                                     print_freq=10)
            writer.add_scalar('epoch-loss', logger.loss.total, epoch)
            # evaluate on the test dataset
            evaler, pick = evaluate(model, data_loader_test, device=device)
            bbox_stats = evaler.coco_eval['bbox'].stats
            # show the parameters and test results in TensorBoard
            writer.add_scalar('mAP', bbox_stats[0], epoch)
            writer.add_scalar('AP50', bbox_stats[1], epoch)
            writer.add_scalar('AP75', bbox_stats[2], epoch)
            # If the boxes were X, Y, W, H an extra line would be needed:
            # bboxes[:, 2:] += bboxes[:, :2]
            bboxes = pick[1][0]['boxes']
            labels = pick[1][0]['labels']
            scores = pick[1][0]['scores']
            bboxes_str = [
                "{} {:.2%}".format(class_names[int(label)], float(score))
                for label, score in zip(labels, scores)
            ]
            # Output the mask to TensorBoard; uncomment if needed
            # masks = pick[1][0]['masks']
            # final_mask = torch.zeros(masks.shape[-2:], dtype=torch.bool)
            # for i in masks > 0.5:
            #     final_mask = torch.bitwise_or(i[0], final_mask)
            # writer.add_image('test_mask', final_mask, epoch, dataformats='HW')

            writer.add_image('real_img', pick[0][0], epoch)
            writer.add_image_with_boxes('test_img',
                                        pick[0][0],
                                        bboxes,
                                        epoch,
                                        labels=bboxes_str)

            # AP50 and accuracy on the test set are the metrics used to pick
            # the best model; the latest best model (by AP50 or by accuracy)
            # is saved.
            if max_AP < bbox_stats[1]:
                max_AP = bbox_stats[1]
                torch.save(model, checkpoint_path)
            acc, cm, ps, rc = custom_eval_sanan_when_train(model,
                                                           data_loader_test,
                                                           device=device)
            writer.add_scalar('acc', acc, epoch)
            if max_acc < acc:
                max_acc = acc
                cm_when_max_acc = cm
                ps_when_max_acc = ps
                rc_when_max_acc = rc
                torch.save(model, checkpoint_path)
            # update the learning rate: it is lowered once AP50 has not
            # exceeded its current maximum for `patience` consecutive epochs
            writer.add_scalar('lr', logger.lr.value, epoch)
            lr_scheduler.step(bbox_stats[1])
    finally:
        # Flush pending events and release the event file even on failure.
        writer.close()

    print("That's it!")
    print('acc:{:.2%}'.format(max_acc))
    print(cm_when_max_acc)
    print('precision:', ['{:.2%}'.format(x) for x in ps_when_max_acc])
    print('recall:', ['{:.2%}'.format(x) for x in rc_when_max_acc])
def train_model(args):
    """Train and validate the object-detection model, logging to TensorBoard.

    Args:
        args: Parsed options.  The attributes read here are ``resize``,
            ``db``, ``images``, ``batch_size``, ``num_data_workers``,
            ``device``, ``pos_anchor_iou``, ``neg_anchor_iou``, ``lr``,
            ``epochs``, ``log_interval``, ``metric_interval`` and ``nms_iou``.

    The dataset is split 75/25 into training and validation parts.  Every
    epoch runs a training pass followed by a validation pass under
    ``torch.no_grad()``, then steps the learning-rate scheduler.  The
    TensorBoard writer is flushed after each logging step and closed once when
    training ends or fails.
    """
    writer = SummaryWriter()

    transforms = DetectionTransform(output_size=args.resize,
                                    greyscale=True,
                                    normalize=True)

    dataset = FBSDetectionDataset(database_path=args.db,
                                  data_path=args.images,
                                  greyscale=True,
                                  transforms=transforms,
                                  categories_filter={'person': True},
                                  area_filter=[100**2, 500**2])
    dataset.print_categories()

    # Split dataset into train and validation
    train_len = int(0.75 * len(dataset))
    dataset_lens = [train_len, len(dataset) - train_len]
    print("Splitting dataset into pieces: ", dataset_lens)
    datasets = torch.utils.data.random_split(dataset, dataset_lens)
    print(datasets)

    # Setup the data loader objects (collation, batching)
    loader = torch.utils.data.DataLoader(collate_fn=collate_detection_samples,
                                         dataset=datasets[0],
                                         batch_size=args.batch_size,
                                         pin_memory=True,
                                         num_workers=args.num_data_workers)

    validation_loader = torch.utils.data.DataLoader(
        dataset=datasets[1],
        batch_size=args.batch_size,
        pin_memory=True,
        collate_fn=collate_detection_samples,
        num_workers=args.num_data_workers)

    # Select device (cpu/gpu)
    device = torch.device(args.device)

    # Create the model and transfer weights to device
    model = ObjectDetection(input_image_shape=args.resize,
                            pos_threshold=args.pos_anchor_iou,
                            neg_threshold=args.neg_anchor_iou,
                            num_classes=len(dataset.categories),
                            predict_conf_threshold=0.5).to(device)

    # Select optimizer
    optim = torch.optim.SGD(params=model.parameters(),
                            lr=args.lr,
                            momentum=0.5)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim,
                                                   step_size=2,
                                                   gamma=0.1,
                                                   last_epoch=-1)

    try:
        # Outer training loop
        for epoch in range(1, args.epochs + 1):
            # Inner training loop
            print("\n BEGINNING TRAINING STEP EPOCH {}".format(epoch))
            cummulative_loss = 0.0
            start_time = time.time()

            batch: ObjectDetectionBatch
            for idx, batch in enumerate(loader):
                # Reset gradient
                optim.zero_grad()

                # Push the data to the gpu (if necessary)
                batch.to(device)
                # NOTE(review): debug is enabled when idx is a multiple of
                # log_interval, while logging below fires when idx + 1 is --
                # confirm these are meant to differ.
                batch.debug = idx % args.log_interval == 0

                # Run the model
                losses, model_data = model(batch)
                cummulative_loss += losses["class_loss"].item()

                # Calc gradient and step optimizer.
                losses['class_loss'].backward()
                optim.step()

                # Log Metrics and Visualizations
                if (idx + 1) % args.log_interval == 0:
                    step = (epoch - 1) * len(loader) + idx + 1

                    print(
                        "Ep {} Training Step {} Batch {}/{} Loss : {:.3f}".
                        format(epoch, step, idx, len(loader),
                               cummulative_loss))

                    # Save visualizations and metrics with tensorboard.
                    # Note: For research, to reproduce graphs you will want
                    # some way to save the collected metrics (e.g. the loss
                    # values) to an array for recreating figures for a paper.
                    # To do so, metrics are often wrapped in a "metering"
                    # class that takes care of logging to tensorboard,
                    # resetting cumulative metrics, saving arrays, etc.
                    #
                    # training_image - the raw training images with box
                    #   labels
                    # training_image_predicted_anchors - predictions for the
                    #   same image, using basic thresholding (0.7 confidence
                    #   on the logit)
                    # training_image_predicted_post_nms - predictions for the
                    #   same image, filtered at 0.7 confidence followed by
                    #   Non-Max-Suppression
                    # training_image_positive_anchors - shows anchors which
                    #   received a positive label in the labeling step in the
                    #   model
                    sample_image = normalize_tensor(batch.images[0])
                    writer.add_image_with_boxes("training_image",
                                                sample_image,
                                                box_tensor=batch.boxes[0],
                                                global_step=step)
                    writer.add_image_with_boxes(
                        "training_image_predicted_anchors",
                        sample_image,
                        model_data["pos_predicted_anchors"][0],
                        global_step=step)

                    keep_ind = nms(model_data["pos_predicted_anchors"][0],
                                   model_data["pos_predicted_confidence"][0],
                                   iou_threshold=args.nms_iou)

                    writer.add_image_with_boxes(
                        "training_image_predicted_post_nms",
                        sample_image,
                        model_data["pos_predicted_anchors"][0][keep_ind],
                        global_step=step)
                    writer.add_image_with_boxes(
                        "training_image_positive_anchors",
                        sample_image,
                        box_tensor=model_data["pos_labeled_anchors"][0],
                        global_step=step)

                    # Scalars - batch_time, training loss
                    writer.add_scalar(
                        "batch_time",
                        ((time.time() - start_time) /
                         float(args.log_interval)) * 1000.0,
                        global_step=step)
                    writer.add_scalar("training_loss",
                                      losses['class_loss'].item(),
                                      global_step=step)

                    writer.add_scalar(
                        "avg_pos_labeled_anchor_conf",
                        torch.tensor([
                            c.mean()
                            for c in model_data["pos_labeled_confidence"]
                        ]).mean().item(),
                        global_step=step)

                    start_time = time.time()

                    # Push the events out now but keep the writer open; it
                    # is closed exactly once, when training ends.
                    writer.flush()

                # Reset metric meters as necessary
                if idx % args.metric_interval == 0:
                    cummulative_loss = 0.0

            # Inner validation loop
            print("\nBEGINNING VALIDATION STEP {}\n".format(epoch))
            with torch.no_grad():
                batch: ObjectDetectionBatch
                for idx, batch in enumerate(validation_loader):
                    # Push the data to the gpu (if necessary)
                    batch.to(device)
                    batch.debug = idx % args.log_interval == 0

                    # Run the model
                    losses, model_data = model(batch)

                    if idx % args.log_interval == 0:
                        step = (epoch -
                                1) * len(validation_loader) + idx + 1

                        print(
                            "Ep {} Validation Step {} Batch {}/{} Loss : {:.3f}"
                            .format(epoch, step, idx, len(validation_loader),
                                    losses["class_loss"].item()))

                        # Log Images
                        sample_image = normalize_tensor(batch.images[0])
                        writer.add_image_with_boxes(
                            "validation_images",
                            sample_image,
                            box_tensor=batch.boxes[0],
                            global_step=step)

                        writer.add_image_with_boxes(
                            "validation_img_predicted_anchors",
                            sample_image,
                            model_data["pos_predicted_anchors"][0],
                            global_step=step)

                        keep_ind = nms(
                            model_data["pos_predicted_anchors"][0],
                            model_data["pos_predicted_confidence"][0],
                            iou_threshold=0.5)
                        print("Indicies after NMS: ", keep_ind,
                              model_data["pos_predicted_confidence"][0].shape,
                              model_data["pos_predicted_anchors"][0].shape)

                        writer.add_image_with_boxes(
                            "validation_img_predicted_post_nms",
                            sample_image,
                            model_data["pos_predicted_anchors"][0][keep_ind],
                            global_step=step)

                        # Log Scalars
                        writer.add_scalar("validation_loss",
                                          losses['class_loss'].item(),
                                          global_step=step)
                        writer.flush()

            lr_scheduler.step()
            # get_last_lr() reports the rate the scheduler just computed;
            # get_lr() is only meant to be called from inside step().
            print("Stepped learning rate. Rate is now: ",
                  lr_scheduler.get_last_lr())
    finally:
        # Release the event file whether training finished or failed.
        writer.close()
예제 #6
0
class Trainer:
    '''
    a class for train SPAIR.
    with support of drawing boundary, logging and scoring.

    Intended to be used as a context manager: leaving the ``with`` block
    closes the summary writer.
    '''
    def __init__(self, Implement=SPAIR, **config):
        '''
        Build the model, optionally resuming from ``config["model_path"]``.

        Implement: callable that builds the model from ``config["model"]``
        config: keyword configuration; when a checkpoint exists, its saved
            config is used as the base and these values override it
        '''
        save_cfg = self.checkpoint(config.get("model_path", None))
        if save_cfg:
            self.loadConfig(config, save_cfg["config"])
            self.start_epoch = save_cfg["epoch"]
            self.spair = Implement(**self.config.get('model', {}))
            self.spair.load_state_dict(save_cfg["model"])

        else:
            self.loadConfig(config)
            self.start_epoch = 0
            self.spair = Implement(**self.config.get('model', {}))

        self.kl_bulder = KL_Builder(**self.config['KL_Builder'])
        # SECURITY NOTE(review): eval() executes the configured optimizer
        # name as Python code.  Only load configs/checkpoints from trusted
        # sources, or replace this with an explicit name lookup.
        self.op = eval(self.optimizer)(self.spair.parameters(),
                                       **self.config.get(self.optimizer, {}))
        self.spair.to(self.device)

    @staticmethod
    def checkpoint(path) -> dict:
        '''
        Load the saved state at `path` if that file exists.

        return: the object returned by `torch.load`, or None when `path` is
            empty or the file is missing.  In the latter case the parent
            folder is created so a later `save` can write there.
        '''
        if not path:
            return None
        if os.path.exists(path):
            return torch.load(path)
        folder = os.path.dirname(path)
        if folder:  # a bare file name has no folder to create
            os.makedirs(folder, exist_ok=True)
        return None

    def loadConfig(self, config, default: dict = None):
        '''
        Merge `config` over a copy of `default` and derive the settings
        (summary writer, device, optimizer name, KL weight) from the result.
        '''
        if default: self.config = default.copy()
        else: self.config = {}
        self.config.update(config)
        self.summary = SummaryWriter(self.config["logdir"],
                                     self.config["logdir"].split('/')[-1])
        self.device = torch.device(self.config.get('device', 'cuda:0'))
        self.optimizer = self.config.get('optimizer', 'Adam')
        self._sup = self.config.get('KL_lambda', 1)

    def save(self):
        '''
        Write config, current epoch and model weights to
        `config["model_path"]`.
        '''
        d = {
            "config": self.config,
            "epoch": self.cur_epoch,
            "model": self.spair.state_dict()
        }

        folder = os.path.dirname(self.config["model_path"])
        if folder:  # a bare file name has no folder to create
            os.makedirs(folder, exist_ok=True)

        torch.save(d, self.config["model_path"])

    def loss(self, rec, target, param_dict: dict, global_step) -> torch.Tensor:
        '''
        Total loss: summed reconstruction BCE plus `KL_lambda` times the two
        KL terms.  Each of the three terms is also logged as a scalar.
        '''
        recon_loss = binary_cross_entropy(rec, target, reduction='sum')
        norm_loss = self.kl_bulder.norm_KL(param_dict)
        kin_loss = self.kl_bulder.bin_KL(param_dict['pres'], global_step)
        assert not torch.isnan(recon_loss)
        self.summary.add_scalar('loss/reconstruct', recon_loss, global_step)
        self.summary.add_scalar('loss/normal', norm_loss, global_step)
        self.summary.add_scalar('loss/bernoulli', kin_loss, global_step)
        return recon_loss + self._sup * (norm_loss + kin_loss)

    def __histogram(self,
                    tag: str,
                    value: torch.Tensor,
                    global_step=None,
                    dim=None):
        '''
        Log `value` as a histogram; when `dim` is given, log one histogram
        per slice along that dimension instead.
        '''
        with torch.no_grad():
            # `is not None` so that dimension 0 is honoured as well
            if dim is not None:
                for i, v in enumerate(torch.split(value, 1, dim=dim)):
                    self.summary.add_histogram(
                        "sample/%s/%s_%d" % (tag, tag, i), v, global_step)
            else:
                self.summary.add_histogram("sample/%s" % tag, value,
                                           global_step)

    def rectangle(self, rec, norm_box, pres, global_step):
        '''
        Log every image in `rec` with the boxes whose presence rounds to 1.
        The inputs are not modified.

        rec: [N, C, H_img, W_img]
        norm_box: [N, H*W, 4], y_center, x_center, height, width
        pres: [N, H*W, 1]
        '''
        with torch.no_grad():
            image_size = self.spair.encoder.image_size
            # Work on a copy so the caller's tensor is left untouched.
            scaled = norm_box.clone()
            scaled[:, :, :2] *= image_size
            scaled[:, :, 2:] *= image_size / 2
            scaled = torch.round(
                scaled)  # y_center, x_center, height / 2, width / 2
            corners = torch.stack(
                (
                    scaled[:, :, 1] - scaled[:, :, 3],  # xmin
                    scaled[:, :, 0] - scaled[:, :, 2],  # ymin
                    scaled[:, :, 1] + scaled[:, :, 3],  # xmax
                    scaled[:, :, 0] + scaled[:, :, 2],  # ymax
                ),
                dim=-1)
            present = torch.round(pres).bool().squeeze(-1)

            for i, (img, box, zpres) in enumerate(zip(rec, corners, present)):
                # Boolean indexing yields an empty [0, 4] tensor when nothing
                # is present, so such an image is logged without boxes.
                self.summary.add_image_with_boxes('detect/rec_%d' % i, img,
                                                  box[zpres], global_step)

    def reconstruct(self, X, bg=None):
        '''
        Run the model in eval mode on `X` without tracking gradients.

        X: [N, H, W, C], or [H, W, C] for a single image
        bg: same layout as X; zeros are used when omitted
        return: the model output, moved back to X's device
            NOTE(review): the output is not permuted back here -- confirm
            its layout against the model.
        '''
        self.spair.eval()
        # Out-of-place unsqueeze: the caller's tensor keeps its shape.
        if X.dim() == 3: X = X.unsqueeze(0)
        if bg is None: bg = torch.zeros_like(X)
        elif bg.dim() == 3: bg = bg.unsqueeze(0)
        prev_device = X.device
        with torch.no_grad():
            return self.spair(
                X.to(self.device).permute(0, 3, 1, 2),
                bg.to(self.device).permute(0, 3, 1, 2)).to(prev_device)

    def train(self, X: torch.Tensor, bg=None):
        '''
        Generator that trains for one epoch per iteration and yields the
        epoch index after each one.  The model is saved when all epochs are
        done; on KeyboardInterrupt the user is asked whether to save.

        X shape: [N, H, W, C];
        '''
        # torch.autograd.set_detect_anomaly(True)
        if bg is None: bg = torch.zeros_like(X)
        data = TensorDataset(
            X.permute(0, 3, 1, 2).to(self.device),  # [N, 3, H, W]
            bg.permute(0, 3, 1, 2).to(self.device)  # [N, 3, H, W]
        )
        loader = DataLoader(data, **self.config.get("loader", {}))
        max_batch = len(loader)

        # ==============================
        # ========= train here =========
        def one_epoch():
            # global step of the first batch of the current epoch
            citer = max_batch * self.cur_epoch
            for batch, (R, B) in enumerate(loader):
                self.op.zero_grad()
                where, what, depth, pres, pd = self.spair.encoder(R)
                self.__histogram('where', pd['where'].mean, citer + batch, -1)
                self.__histogram('what', what, citer + batch)
                self.__histogram('depth', depth, citer + batch)
                self.__histogram('pres', pres, citer + batch)
                rec = self.spair.decoder(where, what, depth, pres, B)
                loss = self.loss(rec, R, pd, citer + batch)
                loss.backward()
                self.op.step()
                bar.next()

        # ========= train end ==========
        # ==============================

        for self.cur_epoch in range(self.start_epoch,
                                    self.config["max_epoch"]):
            bar = Bar('epoch%3d' % (self.cur_epoch + 1), max=max_batch)
            try:
                self.spair.train()
                one_epoch()
                yield self.cur_epoch  # see what my caller wanna do after one epoch
                bar.finish()
            except KeyboardInterrupt:
                bar.finish()
                if 'Y' == input('save model? Y/n ').upper():
                    self.save()
                    print("Saved. Start from epoch%d next time." %
                          (self.cur_epoch + 1))
                return

        self.cur_epoch = self.config["max_epoch"]
        self.save()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.summary.close()
class ObjectDetectionTrainer:
    """Object detection training manager.
    """
    def __init__(self,
                 max_epoch: int,
                 dataset_train_path: str,
                 dataset_validation_path: str,
                 output_grid_size: int,
                 anchor_boxes: List[BoundingBox],
                 target_device: str,
                 label_map: Dict[str, int],
                 target_image_size: Tuple[int, int] = (320, 320),
                 n_classes: int = 1):
        """Build the datasets, model, optimizer, log writer and LR scheduler.

        Args:
            max_epoch (int): Number of epochs ``start_training`` iterates over.
            dataset_train_path (str): Path handed to the training dataset.
            dataset_validation_path (str): Path handed to the validation dataset.
            output_grid_size (int): Number of grids for the prediction. The model will output (output_grid_size, output_grid_size) cells.
            anchor_boxes (List[BoundingBox]): Anchor boxes; their count is the number of boxes predicted per cell.
            target_device (str): Device the model is moved to. Example ("cuda:0", "cpu:0")
            label_map (Dict[str, int]): Forwarded to both datasets (presumably label name -> class id; confirm against ObjectDetectionDataset).
            target_image_size (Tuple[int,int], optional): The expected size of the network input (width, height). Defaults to (320,320).
            n_classes (int, optional): Number of classes the model predicts. Defaults to 1.
        """
        # Plain configuration values
        self.n_classes = n_classes
        self.max_epoch = max_epoch
        self.target_image_size = target_image_size

        # Both datasets share everything except their source path
        shared_dataset_args = {
            "transform": image_bb_transforms,
            "label_map": label_map,
        }
        self.dataset_train = ObjectDetectionDataset(dataset_train_path,
                                                    target_image_size,
                                                    **shared_dataset_args)
        self.dataset_validation = ObjectDetectionDataset(
            dataset_validation_path, target_image_size, **shared_dataset_args)

        # One predicted box per anchor in every grid cell
        self.anchor_boxes = anchor_boxes
        self.boxes_per_cell = len(anchor_boxes)
        self.output_grid_size = output_grid_size
        self.target_device = target_device
        self.label_map = label_map

        self.model = ObjectDetectionModel(n_classes, output_grid_size,
                                          self.boxes_per_cell)
        self.model.to(target_device)
        self.optimizer = Adam(self.model.parameters(), lr=1e-3)

        # The start timestamp doubles as a unique id for logs and checkpoints
        self.run_id = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
        self.writer = SummaryWriter("logs/%s" % (self.run_id))

        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           mode='min',
                                           patience=10,
                                           verbose=True)

    def calculate_iou(self, b1: BoundingBox, b2: BoundingBox) -> float:
        """Intersection-over-union of two axis-aligned boxes.

        Widths and heights are measured as ``max - min + 1``.

        NOTE(review): the ``+ 1`` terms are the inclusive-pixel convention,
        while other methods of this class scale box coordinates by the grid
        or image size as if they were fractions of the image. Confirm which
        coordinate system is intended here.

        Args:
            b1 (BoundingBox): First bounding box
            b2 (BoundingBox): Second bounding box

        Returns:
            float: iou score between the boxes
        """

        def extent(low, high):
            # Side length under the "+ 1" convention
            return high - low + 1

        # Overlap rectangle: largest of the minima, smallest of the maxima
        overlap_w = extent(max(b1.min_x, b2.min_x), min(b1.max_x, b2.max_x))
        overlap_h = extent(max(b1.min_y, b2.min_y), min(b1.max_y, b2.max_y))

        # A non-positive extent on either axis means the boxes are disjoint
        shared_area = max(0, overlap_w) * max(0, overlap_h)

        first_area = extent(b1.min_x, b1.max_x) * extent(b1.min_y, b1.max_y)
        second_area = extent(b2.min_x, b2.max_x) * extent(b2.min_y, b2.max_y)
        combined_area = first_area + second_area - shared_area

        return shared_area / float(combined_area)

    # assume output is (batch_size*n_cells, cell_tensor_size)
    def parse_the_outputs(self, output: Type[torch.FloatTensor],
                          boxes: List[List[BoundingBox]],
                          classes: List[List[int]]):
        """Takes the output of the neural network and the target bounding boxes

        Args:
            output (torch.FloatTensor): The output from the neural network, reshaped as (batch_size*n_cells, cell_tensor_size)
            boxes (List[List[BoundingBox]]): List of list of target bounding boxes for each image
            clasess (List[int]): List of list of target classes for each image

        Returns:
            [type]: [description]
        """
        # Each cell has [ self.boxes_per_cell * 5 + n_classes ] entries, extract each relevant part
        bbox_slices = output[:, :self.boxes_per_cell * 5].reshape(-1, 5)
        class_slices = output[:, self.boxes_per_cell * 5:]

        positive_bbox_outputs = []
        positive_bbox_targets = []
        negative_bbox_outputs = []

        positive_classprob_indices = []
        positive_classprob_target = []
        positive_classprob_output = []

        anchor_ids = []
        positive_bbox_indices_list = []

        # total number of cells in the batch output
        n_cells = output.shape[0]

        # gather positive cell indices and corresponding anchor boxes
        for image_id, bb_list in enumerate(boxes):
            # Process all bbs in the image
            for bb_id, bb in enumerate(bb_list):

                # Center in output grid coordinate
                center_x_grid = self.output_grid_size * (bb.min_x +
                                                         bb.max_x) / 2
                center_y_grid = self.output_grid_size * (bb.min_y +
                                                         bb.max_y) / 2

                # Get cell index in the batch
                cell_index = math.floor(
                    center_y_grid) * self.output_grid_size + math.floor(
                        center_x_grid)
                cell_index += image_id * self.model.n_cells

                # Calc bounding box offset within the grid
                offset_x = center_x_grid - math.floor(center_x_grid)
                offset_y = center_y_grid - math.floor(center_y_grid)

                # Calc bb size
                width = (bb.max_x - bb.min_x)
                height = (bb.max_y - bb.min_y)

                matching_anchors = []

                max_anchor_id = -1

                anchor_offset_x = (bb.min_x + bb.max_x) / 2
                anchor_offset_y = (bb.min_y + bb.max_y) / 2

                # Find matching matching boxes
                for anchor_id, anchor in enumerate(self.anchor_boxes):
                    adjusted_anchor = BoundingBox(
                        min_x=anchor.min_x + anchor_offset_x,
                        min_y=anchor.min_y + anchor_offset_y,
                        max_x=anchor.max_x + anchor_offset_x,
                        max_y=anchor.max_y + anchor_offset_y)
                    iou = self.calculate_iou(bb, adjusted_anchor)
                    if iou > 0.5:
                        matching_anchors.append((iou, anchor_id))

                # Find the maximum matching anchor
                if matching_anchors:
                    max_anchor_id = max(matching_anchors,
                                        key=lambda t: t[0])[1]

                    anchor = self.anchor_boxes[max_anchor_id]

                    anchor_w = anchor.max_x - anchor.min_x
                    anchor_h = anchor.max_y - anchor.min_y

                    # calc bbox index
                    bbox_index = cell_index * self.model.boxes_per_cell + max_anchor_id

                    # Calc target tensor
                    target_tensor = torch.FloatTensor(
                        [offset_x, offset_y, width, height, 1.0])

                    positive_bbox_targets.append(target_tensor)

                    positive_bbox_indices_list.append(bbox_index)

                    # class prob target
                    #class_target = torch.zeros(self.n_classes)
                    #class_target[classes[image_id][bb_id]] = 1.0

                    positive_classprob_output.append(class_slices[cell_index])
                    positive_classprob_target.append(classes[image_id][bb_id])

                    anchor_ids.append(max_anchor_id)

        # Make a set for quick membership lookup
        positive_bbox_indices_set = set(positive_bbox_indices_list)
        # Other than positive indices, combine as negative indices
        negative_bbox_indices_list = torch.LongTensor([
            s for s in range(bbox_slices.shape[0])
            if s not in positive_bbox_indices_set
        ])

        # Get positive bbox and negative bbox
        positive_bbox_outputs = bbox_slices[positive_bbox_indices_list, :]
        negative_bbox_outputs = bbox_slices[negative_bbox_indices_list, :]

        positive_classprob_output = torch.stack(positive_classprob_output)
        positive_classprob_target = torch.LongTensor(
            positive_classprob_target).to(self.target_device)

        # Stack together bbox target tensors
        bbox_targets = torch.stack(positive_bbox_targets,
                                   0).to(self.target_device)

        # Keep track of anchor sizes. This will be used as scaling factor for bbox size
        anchor_sizes = [[
            self.anchor_boxes[s].max_x - self.anchor_boxes[s].min_x,
            self.anchor_boxes[s].max_y - self.anchor_boxes[s].min_y
        ] for s in anchor_ids]
        anchor_sizes = np.array(anchor_sizes)

        return positive_bbox_outputs, negative_bbox_outputs, bbox_targets, positive_classprob_output, positive_classprob_target, anchor_sizes

    def preview_generate(self, input_img: Type[torch.FloatTensor],
                         output: Type[torch.FloatTensor],
                         target_box: List[BoundingBox]):
        """Detection results preview generator

        Decodes every predicted box whose sigmoid objectness exceeds 0.1 into
        pixel coordinates, together with the grid cell it came from and the
        target boxes scaled to pixels. ``output`` is only read: decoding is
        done on detached copies so the caller's tensor is left untouched.

        Args:
            input_img (Type[torch.FloatTensor]): Input image (C,H,W)
            output (Type[torch.FloatTensor]) : Network output (-1,CELL_SIZE)
            target_box (List[BoundingBox]): Target boxes of the same image;
                each is converted with ``list(box)`` and indices 0/2 are
                scaled by the image width, 1/3 by the image height.

        Returns:
            dict: ``{"image", "target_bb", "candidate_bb", "cell_bb"}``.
            "image" is the numpy image; "target_bb" and "cell_bb" are stacked
            numpy arrays and "candidate_bb" a stacked tensor, each replaced by
            an empty FloatTensor when there is nothing to show.
        """
        np_image = np_image_from_tensor(input_img)

        img_height = np_image.shape[0]
        img_width = np_image.shape[1]
        grid_size = self.model.output_grid_size

        # Reshape cells to 2D and keep only the bounding box part of each cell
        bb_outputs = output.view(grid_size, grid_size,
                                 -1)[:, :, :self.boxes_per_cell * 5]
        bb_outputs = bb_outputs.reshape(grid_size, grid_size,
                                        self.model.boxes_per_cell, -1)

        # Get objectness outputs
        objectness_output = torch.sigmoid(bb_outputs[:, :, :, -1])

        # Only select if objectness > 0.1, will return list of index triples
        # (grid row, grid column, anchor)
        candidate_cell_index = torch.nonzero(
            objectness_output > 0.1).detach().cpu().numpy().tolist()

        # Cell sizes in normalized coordinate
        cell_size_x = 1.0 / grid_size
        cell_size_y = 1.0 / grid_size

        # Bounding box outputs of candidate boxes
        candidates_bb = []
        # Bounding boxes showing candidate cells
        cell_bboxes = []

        for row, col, anchor_id in candidate_cell_index:
            # Work on a detached copy: the in-place writes below must not
            # leak back into the network output
            preview_bb = bb_outputs[row, col, anchor_id][:4].detach().clone()

            anchor = self.anchor_boxes[anchor_id]
            anchor_w = anchor.max_x - anchor.min_x
            anchor_h = anchor.max_y - anchor.min_y

            # Compute offset and sizes
            preview_bb[:2] = torch.sigmoid(preview_bb[:2])
            preview_bb[2:4] = torch.exp(preview_bb[2:4])
            preview_bb[2] *= anchor_w
            preview_bb[3] *= anchor_h

            # Compute the origin of the cell coordinate in pixel coordinate
            base_x = cell_size_x * col * img_width
            base_y = cell_size_y * row * img_height

            # Top-left corner: cell origin + offset inside the cell - half size
            preview_bb[0] = base_x + preview_bb[
                0] * cell_size_x * img_width - preview_bb[2] * img_width / 2.0
            preview_bb[1] = base_y + preview_bb[
                1] * cell_size_y * img_height - preview_bb[3] * img_height / 2.0

            # Bottom-right corner: top-left + size in pixels
            preview_bb[2] = preview_bb[0] + preview_bb[2] * img_width
            preview_bb[3] = preview_bb[1] + preview_bb[3] * img_height

            # Keep the box inside the image
            preview_bb[[0, 2]] = torch.clamp_min(preview_bb[[0, 2]], 0.0)
            preview_bb[[0, 2]] = torch.clamp_max(preview_bb[[0, 2]],
                                                 img_width - 1)
            preview_bb[[1, 3]] = torch.clamp_min(preview_bb[[1, 3]], 0.0)
            preview_bb[[1, 3]] = torch.clamp_max(preview_bb[[1, 3]],
                                                 img_height - 1)

            candidates_bb.append(preview_bb)

            cell_bboxes.append(
                np.array([
                    base_x, base_y, base_x + cell_size_x * img_width,
                    base_y + cell_size_y * img_height
                ]))

        target_boxes = []
        for target in target_box:
            bb_ref = np.array(list(target))

            # Convert to pixel coordinate
            bb_ref[[0, 2]] *= img_width
            bb_ref[[1, 3]] *= img_height

            target_boxes.append(bb_ref)

        if len(cell_bboxes) > 0:
            cell_bboxes = np.stack(cell_bboxes)
        else:
            cell_bboxes = torch.FloatTensor([])

        if len(target_boxes) > 0:
            target_boxes = np.stack(target_boxes)
        else:
            target_boxes = torch.FloatTensor([])

        if len(candidates_bb) > 0:
            candidates_bb = torch.stack(candidates_bb, dim=0)
        else:
            candidates_bb = torch.FloatTensor([])

        preview_info = {
            "image": np_image,
            "target_bb": target_boxes,
            "candidate_bb": candidates_bb,
            "cell_bb": cell_bboxes
        }

        return preview_info

    def run_batch(self, batch_data: Type[torch.Tensor], epoch: int):
        """Train batch

        Args:
            batch_data (Type[torch.Tensor]): Batch of data
            epoch (int): epoch number

        Returns:
            Dict: Dict of losses
        """
        images = batch_data["images"]
        boxes = batch_data["bboxes"]
        classes = batch_data["labels"]

        images = images.float().to(self.target_device)

        self.optimizer.zero_grad()

        # run network
        output = self.model(images)

        # Reshape to cell outputs
        cell_outputs = output.view(-1, self.model.cell_tensor_size)

        # Parse the cell outputs
        positive_bbox, negative_bbox, positive_bbox_targets, positive_classprob_output, positive_classprob_target, anchor_sizes = self.parse_the_outputs(
            cell_outputs, boxes, classes)

        # Cell structure, ( box1, box2, box3, ..., objectness, class_1, class_2, ... )
        positive_objectness_output = positive_bbox[:, -1].flatten()
        negative_objectness_output = negative_bbox[:, -1].flatten()
        objectness_output = torch.cat(
            [positive_objectness_output, negative_objectness_output], dim=0)

        # generate target objectness tensor
        objectness_target_positive = torch.ones(
            positive_objectness_output.shape[0]).to(self.target_device)
        objectness_target_negative = torch.zeros(
            negative_objectness_output.shape[0]).to(self.target_device)

        objectness_positive_loss = nn.functional.mse_loss(
            input=torch.sigmoid(positive_objectness_output),
            target=objectness_target_positive,
            reduction="sum")
        objectness_negative_loss = nn.functional.mse_loss(
            input=torch.sigmoid(negative_objectness_output),
            target=objectness_target_negative,
            reduction="sum")

        # Apply weight to the losses to handle class imbalance
        objectness_loss = objectness_positive_loss + 0.5 * objectness_negative_loss

        # Update offset to sigmoid
        positive_bbox[:, :2] = torch.sigmoid(positive_bbox[:, :2])
        # width = exp(w) * anchor_w
        # height = exp(h) * anchor_h
        positive_bbox[:, 2:4] = torch.exp(positive_bbox[:, 2:4])
        positive_bbox[:, 2:4] *= torch.from_numpy(anchor_sizes).to(
            self.target_device)

        # Calc bb loss
        bb_loss_offset = nn.functional.mse_loss(positive_bbox[:, :2],
                                                positive_bbox_targets[:, :2],
                                                reduction='sum')
        bb_loss_dims = nn.functional.mse_loss(
            torch.sqrt(positive_bbox[:, 2:4]),
            torch.sqrt(positive_bbox_targets[:, 2:4]),
            reduction='sum')

        bb_loss = bb_loss_offset + bb_loss_dims

        # Calc classes prob loss
        classes_loss = nn.functional.nll_loss(
            input=nn.functional.log_softmax(positive_classprob_output),
            target=positive_classprob_target,
            reduction='sum')

        # Make bb_loss stronger to counter gradient from negative objectness
        total_loss = 1 * objectness_loss + 5 * bb_loss + classes_loss

        total_loss.backward()

        self.optimizer.step()

        preview_info = self.preview_generate(images[0], output[0], boxes[0])

        loss_info = {
            "objectness": objectness_loss.detach().item(),
            "bb": bb_loss.detach().item(),
            "class": classes_loss.detach().item(),
            "total": total_loss.detach().item()
        }

        return loss_info, preview_info

    def start_training(self):
        """Train for ``self.max_epoch`` epochs with periodic validation.

        Every epoch runs all training batches and logs the mean losses plus
        a preview of the last batch. Every 20th epoch (starting at epoch 0)
        additionally runs the validation loader in eval mode, steps the LR
        scheduler, logs the validation results and saves the model weights
        under ``checkpoints/<run_id>/``.
        """
        # Preview panels written to the log: (tag suffix, preview_info key)
        preview_panels = (("target", "target_bb"), ("cell", "cell_bb"),
                          ("detection", "candidate_bb"))

        # Create dataloader to batch the dataset
        dataloader_train = DataLoader(self.dataset_train,
                                      32,
                                      True,
                                      collate_fn=object_dataset_collate_fn,
                                      num_workers=6,
                                      worker_init_fn=worker_init_fn)
        # Validation draws 32 random single-image batches per run
        sampler = torch.utils.data.RandomSampler(self.dataset_validation,
                                                 replacement=True,
                                                 num_samples=32)
        dataloader_validation = DataLoader(
            self.dataset_validation,
            1,
            False,
            collate_fn=object_dataset_collate_fn,
            num_workers=1,
            worker_init_fn=worker_init_fn,
            sampler=sampler)

        for ep in range(self.max_epoch):
            print("epoch : {}".format(ep))
            objectness_loss = []
            bb_loss = []
            total_loss = []
            classes_loss = []

            # Train batches
            for data in dataloader_train:
                losses, preview_images = self.run_batch(data, ep)
                objectness_loss.append(losses["objectness"])
                bb_loss.append(losses["bb"])
                total_loss.append(losses["total"])
                classes_loss.append(losses["class"])

            mean_objectness_loss = np.mean(objectness_loss)
            mean_bb_loss = np.mean(bb_loss)
            mean_total_loss = np.mean(total_loss)
            mean_classes_loss = np.mean(classes_loss)

            self.writer.add_scalar("training/mean_objectness_loss",
                                   mean_objectness_loss, ep)
            self.writer.add_scalar("training/mean_bb_loss", mean_bb_loss, ep)
            self.writer.add_scalar("training/mean_classes_loss",
                                   mean_classes_loss, ep)
            self.writer.add_scalar("training/mean_total_loss", mean_total_loss,
                                   ep)

            # Preview of the last training batch
            for panel, key in preview_panels:
                self.writer.add_image_with_boxes("training/" + panel,
                                                 preview_images["image"],
                                                 preview_images[key],
                                                 ep,
                                                 dataformats='HWC')

            if ep % 20 == 0:
                val_objectness_losses = []
                val_bb_losses = []
                val_class_losses = []
                val_total_losses = []

                # Eval mode for validation
                # NOTE(review): run_batch is reused for validation; confirm
                # it leaves the weights untouched while in eval mode.
                self.model.eval()

                # Validation
                for data in dataloader_validation:
                    losses, preview_images = self.run_batch(data, ep)
                    val_objectness_losses.append(losses["objectness"])
                    val_bb_losses.append(losses["bb"])
                    val_total_losses.append(losses["total"])
                    val_class_losses.append(losses["class"])

                # Return to train mode
                self.model.train()

                # Update learning rate scheduler step
                # NOTE(review): this monitors the training loss, not the
                # validation loss computed below; confirm that is intended.
                self.scheduler.step(mean_total_loss)

                val_mean_objectness_loss = np.mean(val_objectness_losses)
                val_mean_bb_loss = np.mean(val_bb_losses)
                val_mean_total_loss = np.mean(val_total_losses)
                val_mean_classes_loss = np.mean(val_class_losses)

                self.writer.add_scalar("validation/mean_objectness_loss",
                                       val_mean_objectness_loss, ep)
                self.writer.add_scalar("validation/mean_bb_loss",
                                       val_mean_bb_loss, ep)
                self.writer.add_scalar("validation/mean_classes_loss",
                                       val_mean_classes_loss, ep)
                self.writer.add_scalar("validation/mean_total_loss",
                                       val_mean_total_loss, ep)

                # Preview of the last validation batch
                for panel, key in preview_panels:
                    self.writer.add_image_with_boxes("validation/" + panel,
                                                     preview_images["image"],
                                                     preview_images[key],
                                                     ep,
                                                     dataformats='HWC')

                # The checkpoint file name carries the epoch and the mean
                # training loss of that epoch
                checkpoint_dir = "checkpoints/{}".format(self.run_id)
                os.makedirs(checkpoint_dir, exist_ok=True)

                torch.save(
                    self.model.state_dict(),
                    "{}/{}_{}.pth".format(checkpoint_dir, ep, mean_total_loss))