def get_embeddings(image, net, device):
    """Compute an L2-normalized embedding for a single image.

    Parameters
    ----------
    image : image object accepted by the ``cvtransforms`` pipeline.
    net : embedding network; ``net.forward`` maps a (1, 3, 112, 112) batch
        to a feature tensor.
    device : torch device the input tensor is moved to.

    Returns
    -------
    Detached embedding tensor, deliberately left on ``device``
    (e.g. the GPU) so callers can compare embeddings there.
    """
    # Bug fix: the original pipeline included RandomHorizontalFlip here,
    # which made the embedding of the same image non-deterministic at
    # inference time. Removed so repeated calls agree.
    transform = cvtransforms.Compose([
        cvtransforms.Resize((112, 112)),
        cvtransforms.ToTensor(),
        cvtransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    transformed_image = transform(image).to(device)
    # Add the batch dimension expected by the network.
    the_image = Variable(transformed_image).unsqueeze(0)
    # net.eval()  # presumably handled by the caller -- TODO confirm
    embeddings = l2_norm(net.forward(the_image)).detach()  # remain data at gpu
    return embeddings
def get_normalization_mean_std_from_training_set(base_dataset_dict, train_idx,
                                                 device, train_transform_list,
                                                 n_batch):
    """Compute per-channel mean and std over the training subset.

    Only the first and last entries of ``train_transform_list`` are applied
    (presumably resize + ToTensor), so random augmentations in between do
    not bias the statistics. ``n_batch`` is accepted for interface
    compatibility but not used. Returns ``(mean, std)`` tensors on CPU.
    """
    subset_sampler = SubsetRandomSampler(train_idx)
    reduced_pipeline = [train_transform_list[0], train_transform_list[-1]]
    composed = [
        compose_input_output_transform(
            input_transform=cvtransforms.Compose(reduced_pipeline)),
    ]
    dataset = base_dataset_dict["base_dataset"](
        img_dir=base_dataset_dict["datapath"],
        multi_label_gt_path=base_dataset_dict["gt_path"],
        transform=composed[0])
    loader = DataLoader(dataset=dataset, sampler=subset_sampler)

    # Move the channel axis to the front of each batch and flatten all
    # remaining dimensions, then concatenate every batch along dim 1.
    pixel_chunks = [
        batch['input'].to(device).transpose_(0, -3).flatten(start_dim=1)
        for batch in loader
    ]
    all_pixels = torch.cat(pixel_chunks, dim=1)

    cpu = torch.device('cpu')
    return all_pixels.mean(dim=1).to(cpu), all_pixels.std(dim=1).to(cpu)
# Esempio n. 3
# 0
def validate(args):
    """Run ImageNet-style validation for ``args.model``.

    Builds the model (optionally loading ``args.checkpoint``), wraps it in
    DataParallel when ``args.num_gpu > 1``, evaluates top-1/top-5 accuracy
    over the image folder at ``args.data`` and returns an OrderedDict with
    the summary metrics.
    """
    # Fall back to pretrained weights when no checkpoint is supplied.
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         scriptable=args.torchscript)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    valdir = args.data
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # NOTE(review): removed an unused CrossEntropyLoss criterion and an
    # unused cvtransforms pipeline from the original -- the loader below is
    # the only transform path actually exercised.
    loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize((256), interpolation=2),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=False)

    batch_time = AverageMeter()
    # Loss is not computed in this loop; the meter is kept only so the log
    # format string below stays valid.
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(loader):
            target = target.cuda()
            input = input.cuda()

            # compute output -- model presumably returns (logits, aux);
            # the auxiliary output is discarded. TODO confirm against model.
            output, _ = model(input)

            # measure accuracy
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4),
                          param_count=round(param_count / 1e6, 2))

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
# Esempio n. 4
# 0
def valid(datacfg, cfgfile, weightfile, outfile):
    def truths_length(truths):
        for i in range(50):
            if truths[i][1] == 0:
                return i

    # Parse configuration files
    options      = read_data_cfg(datacfg)
    valid_images = options['valid']
    meshname     = options['mesh']
    backupdir    = options['backup']
    name         = options['name']
    if not os.path.exists(backupdir):
        makedirs(backupdir)

    # Parameters
    prefix       = 'results'
    seed         = int(time.time())
    gpus         = '0'     # Specify which gpus to use
    test_width   = 544
    test_height  = 544
    torch.manual_seed(seed)
    use_cuda = True
    if use_cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpus
        torch.cuda.manual_seed(seed)
    save            = False
    testtime        = True
    use_cuda        = True
    num_classes     = 1
    testing_samples = 0.0
    eps             = 1e-5
    notpredicted    = 0 
    conf_thresh     = 0.1
    nms_thresh      = 0.5 # was 0.4
    match_thresh    = 0.5
    y_dispay_thresh = 144
    # Try to load a previously generated yolo network graph in ONNX format:
    #onnx_file_path = './cargo_yolo2.onnx'
    #engine_file_path = './cargo_yolo2.trt'
    #onnx_file_path = './cargo_yolo2_c920_cam.onnx'
    #engine_file_path = './cargo_yolo2_c920_cam.trt'
    onnx_file_path = './cargo_yolo2_c920_cam_83percent.onnx'
    engine_file_path = './cargo_yolo2_c920_cam_83percent.trt'

    if save:
        makedirs(backupdir + '/test')
        makedirs(backupdir + '/test/pr')

    # To save
    testing_error_trans = 0.0
    testing_error_angle = 0.0
    testing_error_pixel = 0.0
    errs_2d             = []
    errs_3d             = []
    errs_trans          = []
    errs_angle          = []
    errs_corner2D       = []
    preds_trans         = []
    preds_rot           = []
    preds_corners2D     = []

    # Read object model information, get 3D bounding box corners
    mesh          = MeshPly(meshname)
    vertices      = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
    print('vertices', vertices)
    corners3D     = get_3D_corners(vertices)
    print('corners3D', corners3D)
    # diam          = calc_pts_diameter(np.array(mesh.vertices))
    diam          = float(options['diam'])

    # Read intrinsic camera parameters
    internal_calibration = get_camera_intrinsic()
    dist = get_camera_distortion_mat()

    # Get validation file names
    with open(valid_images) as fp:
        tmp_files = fp.readlines()
        valid_files = [item.rstrip() for item in tmp_files]
    
    # Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode
    # comment out since we are loading TRT model using get_engine() function
    # model = Darknet(cfgfile)
    # model.print_network()
    # model.load_weights(weightfile)
    # model.cuda()
    # model.eval()
    model_input_size = [416, 416]
    # print('model.anchors', model.anchors)
    # print('model.num_anchors', model.num_anchors)

    # specify the webcam as camera
    colors = pkl.load(open("pallete", "rb"))
    cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)


    #transform = transforms.Compose([transforms.ToTensor(),])

    transform = cvtransforms.Compose([
 
         cvtransforms.Resize(size=(416, 416), interpolation='BILINEAR'),
 
         cvtransforms.ToTensor()
 
         ])

    with get_engine(onnx_file_path, engine_file_path) as engine, engine.create_execution_context() as context:
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)

        while cap.isOpened():
            retflag, frame = cap.read() 
            if retflag:
                #resize_frame = cv2.resize(frame, (416, 416), interpolation = cv2.INTER_AREA)
                img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img = cv2.undistort(img, internal_calibration, dist, None, internal_calibration)
                yolo_img =cv2.resize(img, (416, 416), interpolation=cv2.INTER_AREA)
                box_pr_multi = do_detect_trt(context, yolo_img, conf_thresh, nms_thresh, bindings, inputs, outputs, stream)

                for box_pr in box_pr_multi:
                    corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')           
                    corners2D_pr[:, 0] = corners2D_pr[:, 0] * 1280
                    corners2D_pr[:, 1] = corners2D_pr[:, 1] * 720
                    preds_corners2D.append(corners2D_pr)
                    
                    # Compute [R|t] by pnp
                    _, R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'),  corners2D_pr, np.array(internal_calibration, dtype='float32'))

                    corner2d_pr_vertices = []
                    index = 0

                    # not an empty array, AND, all corners are beyond a y threshold
                    if (corners2D_pr.size > 0) and are_corners_greater_than_y_thres(corners2D_pr, y_dispay_thresh):
                        ymin_pt = find_pt_with_smallest_y(corners2D_pr)
                        pt_for_label1= (int(ymin_pt[0]-30), int(ymin_pt[1]-30))
                        pt_for_label2 = (int(ymin_pt[0]-50), int(ymin_pt[1]-10))
                        for pt in corners2D_pr:
                            # print('corners2D_pr', pt)
                            x = pt[0]
                            y = pt[1]
                            pt =(x, y)
                            if y > y_dispay_thresh:
                                white = (255, 255, 255)
                                cv2.circle(frame, pt, 2, white, thickness=2, lineType=8, shift=0)
                                font = cv2.FONT_HERSHEY_SIMPLEX
                                color = (255, 255, 255)
                                font_scale = 0.6
                                pt_for_number = (int(x+5), int(y-5))
                                # only print the center point (index 0)
                                if index == 0:
                                    cv2.putText(frame, str(index), pt_for_number, font, font_scale, color, 2, lineType=8)
                                # skip the centroid, we only want the vertices
                                corner2d_pr_vertices.append(pt)

                                
                            index = index + 1
                        blue = (255,0,0)
                        # print x offset and z offset (depth) above the smallest y point
                        x = float(t_pr[0])
                        x_cord = 'x ' + str("{0:.2f}".format(x)) + 'm'
                        white = (255,255,255)
                        
                        x1 = pt_for_label1[0]
                        y1 = pt_for_label1[1]
                        purple = (132,37,78)
                        #cv2.rectangle(frame, (x1, y1-20), (x1+len(x_cord)*19+60,y1), purple, -1)
                        

                        z = float(t_pr[2])
                        z_cord = 'Depth ' + str("{0:.2f}".format(z)) + 'm'
                        
                        x2 = pt_for_label2[0]
                        y2 = pt_for_label2[1]                        
                        cv2.rectangle(frame, (x2-5, y2-20*2), (x2+len(z_cord)*12,y2+5), purple, -1)
                        cv2.putText(frame, x_cord, pt_for_label1, font, font_scale, white, 1, lineType=8)
                        cv2.putText(frame, z_cord, pt_for_label2, font, font_scale, white, 1, lineType=8)
                        draw_cube(frame, corner2d_pr_vertices)
                        # if z is less than zero; i.e. away from camera
                        if (t_pr[2] < 0):
                            print('x ', round(float(t_pr[0]), 2), 'y ', round(float(t_pr[1]), 2), 'z ', round(float(t_pr[2]), 2))

                if save:
                    preds_trans.append(t_pr)
                    preds_rot.append(R_pr)

                    np.savetxt(backupdir + '/test/pr/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_pr, dtype='float32'))
                    np.savetxt(backupdir + '/test/pr/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_pr, dtype='float32'))
                    np.savetxt(backupdir + '/test/pr/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_pr, dtype='float32'))


                    # Compute 3D distances
                    transform_3d_pred = compute_transformation(vertices, Rt_pr)  
                    vertex_dist       = np.mean(norm3d)


                cv2.imshow('6D pose estimation', frame)
                detectedKey = cv2.waitKey(1) & 0xFF
                if detectedKey == ord('c'):
                    timestamp = time.time()
                    cv2.imwrite('./screenshots/screeshot' + str(timestamp) + '.jpg', frame)
                    print('captured screeshot')
                elif detectedKey == ord('q'):
                    print('quitting program')
                    break

                # if cv2.waitKey(1) & 0xFF == ord('q'):
                #     break

                t5 = time.time()
            else:
                break



            if False:
                print('-----------------------------------')
                print('  tensor to cuda : %f' % (t2 - t1))
                print('         predict : %f' % (t3 - t2))
                print('get_region_boxes : %f' % (t4 - t3))
                print('            eval : %f' % (t5 - t4))
                print('           total : %f' % (t5 - t1))
                print('-----------------------------------')


    if save:
        predfile = backupdir + '/predictions_linemod_' + name +  '.mat'
        scipy.io.savemat(predfile, {'R_prs': preds_rot, 't_prs':preds_trans, 'corner_prs': preds_corners2D})
# Esempio n. 5
# 0
        cvtransforms.RandomHorizontalFlip(),
        cvtransforms.RandomVerticalFlip(),
        cvtransforms.RandomRotation(90),
        cvtransforms.ToTensor(),
        # cvtransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]

    val_input_transform_list = [
        cvtransforms.Resize(size=input_tensor_res, interpolation='BILINEAR'),
        cvtransforms.ToTensor(),
        # cvtransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]

    train_transforms = [
        compose_input_output_transform(
            input_transform=cvtransforms.Compose(train_input_transform_list)),
    ]

    base_dataset = mpImage_4C_sorted_by_patient_dataset(
        img_dir=args.img_path,
        multi_label_gt_path=gt_path,
        transform=train_transforms[0])

    num_classes = base_dataset[0]["gt"].shape[-1]
    # Split data into cross-validation_set
    # cv_split_list = nfold_cross_validation(len(train_dataset), n_fold=2)
    # cv_split_list = nfold_cross_validation(4, n_fold=2)
    # cv_split_list = leave_one_out_cross_validation(len(base_dataset))
    cv_split_list = leave_one_patient_out_cross_validation(
        len(base_dataset), patient_deid=base_dataset.patient_deid_list)
    # cv_split_list = leave_one_out_cross_validation(2)
                                  cvtransforms.RandomHorizontalFlip(),
                                  cvtransforms.RandomVerticalFlip(),
                                  cvtransforms.RandomRotation(90),
                                  cvtransforms.ToTensor(),
                                  # cvtransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                  ]

    train_input_transform_list = [cvtransforms.ToTensor()]

    val_input_transform_list = [cvtransforms.Resize(size=input_tensor_res, interpolation='BILINEAR'),
                                cvtransforms.ToTensor(),
                                # cvtransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                ]

    train_transforms = [
        compose_input_output_transform(input_transform=cvtransforms.Compose(train_input_transform_list)),
        ]

    base_dataset = mpImage_4C_sorted_by_patient_dataset(img_dir=args.datapath,
                                                     multi_label_gt_path=gt_path,
                                                     transform=train_transforms[0])



    num_classes = base_dataset[0]["gt"].shape[-1]
    # Split data into cross-validation_set
    # cv_split_list = nfold_cross_validation(len(train_dataset), n_fold=2)
    # cv_split_list = nfold_cross_validation(4, n_fold=2)
    # cv_split_list = leave_one_out_cross_validation(len(base_dataset))
    cv_split_list = leave_one_patient_out_cross_validation(len(base_dataset),
                                                           patient_deid=base_dataset.patient_deid_list)
# Esempio n. 7
# 0
                                 cvtransforms.RandomVerticalFlip(),
                                 cvtransforms.RandomRotation(90),
                                 cvtransforms.ToTensor(),
                                 cvtransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]

    # hflip_input_transform_list = [cvtransforms.Resize(size=input_tensor_size, interpolation='BILINEAR'),
    #                              cvtransforms.RandomHorizontalFlip(p=1),
    #                              cvtransforms.ToTensor(),
    #                              cvtransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]



    # output_transform =

    # CV_data set
    train_val_transforms = [compose_input_output_transform(input_transform=cvtransforms.Compose(train_val_input_transform_list)),
                  ]

    train_val_dataset = torch.utils.data.ConcatDataset([
                mpImage_sorted_by_image_dataset(img_dir=args.datapath, gt_path=gt_path, transform=t) for t in train_val_transforms])


    num_classes = train_val_dataset[-1]["gt"].shape[-1]
    # Split data into cross-validation_set and test_set
    cv_split_indices, test_indices = cross_validation_and_test_split(len(train_val_dataset))
    print(cv_split_indices, test_indices)

    cv_data_samplers = [SubsetRandomSampler(cv_split_index) for cv_split_index in cv_split_indices]

    cv_data_loaders = [DataLoader(dataset=train_val_dataset, batch_size=args.n_batch, sampler=cv_data_sampler
                                  ) for cv_data_sampler in cv_data_samplers]
# Esempio n. 8
# 0
            train = train.append(csv_train_category)

        if category_idx == 0:
            csv_test_category = csv_test_file[csv_test_file["code"] ==
                                              code_list[category_idx]]
            test = csv_test_category
        else:
            csv_test_category = csv_test_file[csv_test_file["code"] ==
                                              code_list[category_idx]]
            test = test.append(csv_test_category)

    print("Total train set number: %i" % (len(train)))
    print("Total test set number: %i" % (len(test)))

    transformation = {
        'train': cvtransforms.Compose([cvtransforms.Resize((256, 256))]),
        'test': cvtransforms.Compose([cvtransforms.Resize((256, 256))])
    }

    dataset_train = {
        x: classification_Dataset(basic_train_path=args.basic_train_path,
                                  basic_test_path=args.basic_test_path,
                                  csv_train=train,
                                  csv_test=test,
                                  code_dict=code_dict,
                                  transformation=transformation[x],
                                  mode=x)
        for x in ['train', 'test']
    }

    data_loader = {
def training_pipeline_per_fold(nth_trainer,
                               epochs,
                               nth_fold,
                               base_dataset_dict,
                               train_transform_list,
                               val_transform_list,
                               cv_splits,
                               gpu_count,
                               n_batch,
                               label_idx,
                               params_list=None):
    """Train and validate ``nth_trainer`` on one cross-validation fold.

    Parameters
    ----------
    nth_trainer : trainer exposing ``model_dict``, ``model_init``,
        ``running_model`` and ``train_data_normal``.
    epochs : number of epochs to run.
    nth_fold : index of the fold within ``cv_splits``.
    base_dataset_dict : dict with "base_dataset" (class), "datapath" and
        "gt_path" entries used to construct the datasets.
    train_transform_list / val_transform_list : base transform lists; a
        resize to the model input resolution is prepended and a
        normalization appended to copies (the originals are not mutated).
    cv_splits : sequence of (train_indices, val_indices) pairs.
    gpu_count : number of available GPUs (currently unused; a single
        'cuda' device is selected when available).
    n_batch : batch size for both data loaders.
    label_idx : index of the ground-truth column to train against.
    params_list : optional extra parameters; currently unused, kept for
        interface compatibility.

    Returns
    -------
    The same ``nth_trainer``, with ``nth_trainer.model`` reset to None.
    """
    # Bug fix: the original signature used a mutable default
    # (params_list=[]), which is shared across calls. Use a None sentinel.
    if params_list is None:
        params_list = []

    cv_split = cv_splits[nth_fold]
    # Copy so the caller's transform lists are not mutated across folds.
    train_transform_list_temp = train_transform_list.copy()
    val_transform_list_temp = val_transform_list.copy()

    # Resize inputs to the model's expected (H, W).
    input_tensor_res = (nth_trainer.model_dict['input_size'][-2],
                        nth_trainer.model_dict['input_size'][-1])
    train_transform_list_temp.insert(
        0, cvtransforms.Resize(size=input_tensor_res,
                               interpolation='BILINEAR'))
    val_transform_list_temp.insert(
        0, cvtransforms.Resize(size=input_tensor_res,
                               interpolation='BILINEAR'))

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Normalization: either fixed ImageNet statistics, or statistics
    # computed from this fold's training subset.
    if not nth_trainer.train_data_normal:
        train_normal = cvtransforms.Normalize([0.485, 0.456, 0.406],
                                              [0.229, 0.224, 0.225])
    else:
        train_mean, train_std = get_normalization_mean_std_from_training_set(
            base_dataset_dict=base_dataset_dict,
            train_idx=cv_split[0],
            device=device,
            train_transform_list=train_transform_list_temp,
            n_batch=n_batch)
        train_normal = cvtransforms.Normalize(train_mean, train_std)
    train_transform_list_temp.append(train_normal)
    val_transform_list_temp.append(train_normal)

    train_transforms = [
        compose_input_output_transform(
            input_transform=cvtransforms.Compose(train_transform_list_temp)),
    ]

    train_data = torch.utils.data.ConcatDataset([
        base_dataset_dict["base_dataset"](
            img_dir=base_dataset_dict["datapath"],
            multi_label_gt_path=base_dataset_dict["gt_path"],
            transform=t) for t in train_transforms
    ])

    val_transforms = [
        compose_input_output_transform(
            input_transform=cvtransforms.Compose(val_transform_list_temp)),
    ]
    val_data = torch.utils.data.ConcatDataset([
        base_dataset_dict["base_dataset"](
            img_dir=base_dataset_dict["datapath"],
            multi_label_gt_path=base_dataset_dict["gt_path"],
            transform=t) for t in val_transforms
    ])

    train_data_loader = DataLoader(dataset=train_data,
                                   batch_size=n_batch,
                                   num_workers=0,
                                   sampler=SubsetRandomSampler(cv_split[0]))
    val_data_loader = DataLoader(dataset=val_data,
                                 batch_size=n_batch,
                                 num_workers=0,
                                 sampler=SubsetRandomSampler(cv_split[1]))

    print("{} {}th fold: {}".format("-" * 10, nth_fold, "-" * 10))
    nth_trainer.model_init()
    nth_trainer.model.to(device)
    running_loss = 0
    ran_data = 0
    running_states = ['train', 'val']
    for epoch in range(epochs):
        print("=" * 30)
        print("{} {}th fold {}th epoch running: {}".format(
            "=" * 10, nth_fold, epoch, "=" * 10))
        epoch_start_time = time.time()

        for running_state in running_states:
            state_start_time = time.time()
            if running_state == "train":
                cv_data_loader = train_data_loader
            else:
                cv_data_loader = val_data_loader
            for batch_idx, data in enumerate(cv_data_loader):
                # Renamed from `input`/`gt` to avoid shadowing the builtin.
                inputs = data['input']
                targets = data['gt'][..., label_idx].unsqueeze(-1)
                deid = data['deid']
                row_idx = data['row_idx']

                inputs = Variable(inputs).float().to(device)
                targets = Variable(targets).float().to(device)

                loss, predict = nth_trainer.running_model(
                    inputs,
                    targets,
                    epoch=epoch,
                    running_state=running_state,
                    nth_fold=nth_fold,
                    deid=deid,
                    row_idx=row_idx)
                ran_data += 1
                running_loss += loss.item()

            state_time_elapsed = time.time() - state_start_time
            print(
                "{}th fold {}th epoch ({}) running time cost: {:.0f}m {:.0f}s".
                format(nth_fold, epoch, running_state,
                       state_time_elapsed // 60, state_time_elapsed % 60))
            print('{}th fold {}th epoch ({}) average loss: {}'.format(
                nth_fold, epoch, running_state, running_loss / ran_data))
            # Reset running statistics between train/val phases.
            running_loss = 0
            ran_data = 0
        time_elapsed = time.time() - epoch_start_time

        print("{}{}th epoch running time cost: {:.0f}m {:.0f}s".format(
            "-" * 5, epoch, time_elapsed // 60, time_elapsed % 60))

    # Release the model so the next fold starts from a fresh model_init().
    nth_trainer.model = None
    return nth_trainer