コード例 #1
0
    def preload_data(self):
        """Build the on-disk cache of resized images and ground-truth arrays.

        The cache is treated as complete when the ``.npz`` file belonging to
        the last sample already exists; nothing is regenerated in that case.
        Otherwise every sample is fetched with ``self.load_index``; the image
        is saved under the ``images_resized`` variant of its path and ``den``
        (cast to float32) is handed to ``save_npy`` under the ``gt_npy``
        variant with a ``.npz`` extension.

        ``self.is_preload`` is False while this method runs and True once it
        has finished.
        """
        print('Pre-loading the data. This may take a while...')

        t = Timer()
        t.tic()
        self.is_preload = False
        # Assigned up front so the summary print at the end also works when
        # the cache is already present (it used to be assigned only in the
        # regeneration branch).
        self.blob_list = list(range(self.num_samples))
        npy_path = os.path.join(
            self.image_path, self.image_files[self.num_samples - 1]).replace(
                'images', 'gt_npy').replace('.jpg', '.npz')
        if not os.path.isfile(npy_path):
            for out_dir in (self.image_path.replace('images', 'images_resized'),
                            self.image_path.replace('images', 'gt_npy')):
                # An interrupted earlier run may have left the directories
                # behind; os.makedirs raises on an existing path.
                if not os.path.isdir(out_dir):
                    os.makedirs(out_dir)
            for i in range(self.num_samples):
                img, den, count = self.load_index(i)
                den = den.astype(np.float32)
                image_path = os.path.join(self.image_path, self.image_files[i])
                img.save(image_path.replace('images', 'images_resized'),
                         quality=100)
                save_npy(
                    image_path.replace('images',
                                       'gt_npy').replace('.jpg', '.npz'), den)
                if i % 50 == 0:
                    print("loaded {}/{} samples".format(i, self.num_samples))
        duration = t.toc(average=False)
        print('Completed loading ', len(self.blob_list), ' files, time: ',
              duration)
        self.is_preload = True
コード例 #2
0
    def preload(self):
        """Re-draw the random selections and prefetch the samples they refer to.

        For every key of ``self.pro_dict`` a value in {0, 1} is drawn with the
        stored probabilities (``self.ncls``), then a random position inside
        ``self.sample_dict[key][drawn value]`` is picked (``self.choices``).
        The first element of each entry selected this way, for every patch in
        ``self.patch_list``, is looked up in ``self.dataloader`` by a pool of
        8 threads and stored in ``self.loaded``.
        """
        drawn_cls = {}
        for key, probs in self.pro_dict.items():
            drawn_cls[key] = np.random.choice([0, 1], p=probs)
        self.ncls = drawn_cls

        drawn_pos = {}
        for key, cls in self.ncls.items():
            last_pos = len(self.sample_dict[key][cls]) - 1
            drawn_pos[key] = random.randint(0, last_pos)
        self.choices = drawn_pos

        self.loaded = {}

        # Distinct dataloader indices needed by the current selection.
        wanted = set()
        for pos in range(self.get_num_samples()):
            key = self.patch_list[pos][0] + 1
            entry = self.sample_dict[key][self.ncls[key]][self.choices[key]]
            wanted.add(entry[0])

        def fetch(idx):
            return idx, self.dataloader[idx]

        load_timer = Timer()
        load_timer.tic()
        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
            for idx, sample in executor.map(fetch, wanted):
                self.loaded[idx] = sample
        print("re-prior crop: %f s" % load_timer.toc(average=False))
コード例 #3
0
def test_model_origin(net,
                      data_loader,
                      save_output=False,
                      save_path=None,
                      test_fixed_size=-1,
                      test_batch_size=1,
                      gpus=None):
    """Evaluate ``net`` on whole images and accumulate counting errors.

    Args:
        net: model; switched to eval mode and called as ``net(data)`` under
            ``torch.no_grad()``.
        data_loader: object providing ``get_loader(batch_size)`` and ``len()``.
        save_output: when True each blob carries an extra ground-truth
            density entry, and the de-normalised image, the ground-truth map
            and the estimated map are written through ``dgen`` to
            ``save_path``.
        save_path: output location handed to the ``dgen`` save helpers.
        test_fixed_size: not used by this function.
        test_batch_size: batch size passed to ``data_loader.get_loader``.
        gpus: optional sequence; only its length is used, to pace the
            progress message.  ``None`` is treated as a single device.

    Returns:
        Tuple ``(mae, mse, detail)``: summed absolute error divided by
        ``len(data_loader)``, square root of the summed squared error divided
        by ``len(data_loader)``, and a text report with one line per sample.
    """
    timer = Timer()
    timer.tic()
    net.eval()
    mae = 0.0
    mse = 0.0
    detail_lines = []
    # The default ``gpus=None`` used to crash on ``len(None)``.
    num_gpus = 1 if gpus is None else len(gpus)
    if save_output:
        print(save_path)
    for i, blob in enumerate(data_loader.get_loader(test_batch_size)):
        if (i * num_gpus + 1) % 100 == 0:
            print("testing %d" % (i + 1))
        if save_output:
            index, fname, data, mask, gt_dens, gt_count = blob
        else:
            index, fname, data, mask, gt_count = blob

        with torch.no_grad():
            dens = net(data)
            if save_output:
                # Undo the per-channel normalisation in place; ``data`` is
                # not needed again after the forward pass above.
                image = data.squeeze_().mul_(torch.Tensor([0.229,0.224,0.225]).view(3,1,1))\
                                        .add_(torch.Tensor([0.485,0.456,0.406]).view(3,1,1)).data.cpu().numpy()

                dgen.save_image(
                    image.transpose((1, 2, 0)) * 255.0, save_path,
                    fname[0].split('.')[0] + "_0_img.png")
                gt_dens = gt_dens.data.cpu().numpy()
                density_map = dens.data.cpu().numpy()
                dgen.save_density_map(gt_dens.squeeze(), save_path,
                                      fname[0].split('.')[0] + "_1_gt.png")
                dgen.save_density_map(density_map.squeeze(), save_path,
                                      fname[0].split('.')[0] + "_2_et.png")
                del gt_dens
        gt_count = gt_count.item()
        et_count = dens.sum().item()

        del data, dens
        detail_lines.append("index: {}; fname: {}; gt: {}; et: {};\n".format(
            i, fname[0].split('.')[0], gt_count, et_count))
        mae += abs(gt_count - et_count)
        mse += ((gt_count - et_count) * (gt_count - et_count))
    mae = mae / len(data_loader)
    mse = np.sqrt(mse / len(data_loader))
    duration = timer.toc(average=False)
    # Single-argument print calls behave the same on Python 2 and 3.
    print("testing time: %d" % duration)
    return mae, mse, ''.join(detail_lines)
コード例 #4
0
def demo(symbol_name, params_path, img_dir):
    """Run the model on every ``*.png`` directly inside ``img_dir``.

    The model is created from ``get_symbol(symbol_name)`` and its weights are
    restored from ``params_path``.  For each image the predicted number and
    the forward time are printed, followed by the timer's average time.

    NOTE(review): ``eval`` is called with ``(model, inputs)``, so it is
    presumably a project function shadowing the builtin - confirm.
    """
    timer = Timer()
    model = get_symbol(symbol_name)()
    restore_weights(model, params_path)

    for img_path in Path(img_dir).glob('*.png'):
        # Close each image file deterministically instead of leaving one open
        # handle per processed image until garbage collection.
        # (assumes ``Image`` is PIL.Image, which supports ``with``)
        with Image.open(img_path) as img:
            inputs = setup_data(img)
            timer.tic()
            num = eval(model, inputs)
            costs = timer.toc()
        print("the number in image %s is %d || forward costs %.4fs" %
              (img_path, num, costs))
    print("average costs: %.4fs" % (timer.average_time))
コード例 #5
0
    def preload_data(self):
        """Load every sample into ``self.blob_list`` through ``self.load_index``.

        ``self.is_preload`` is False while loading and True once all samples
        are stored.  A progress line is printed every 50 samples and the
        elapsed time reported by the timer is printed at the end.
        """
        print ('Pre-loading the data. This may take a while...')

        timer = Timer()
        timer.tic()
        # Placeholder entries (the indices themselves); each one is replaced
        # by the loaded sample in the loop below.
        self.blob_list = list(range(self.num_samples))
        self.is_preload = False
        for idx in range(self.num_samples):
            self.blob_list[idx] = self.load_index(idx)
            if not idx % 50:
                print ("loaded {}/{} samples".format(idx, self.num_samples))
        elapsed = timer.toc(average=False)
        print ('Completed loading ' ,len(self.blob_list), ' files, time: ', elapsed)
        self.is_preload = True
コード例 #6
0
def train(net, train_loader,optimizer, num_epochs):
    """Train ``net`` for ``num_epochs`` passes over ``train_loader``.

    Every blob must provide ``'data'`` and ``'gt_density'``.  The loss is read
    from ``net.loss`` after the forward call ``net(im_data, gt_data)``.  On
    every step whose index is a multiple of the module-level
    ``disp_interval`` the current sample is saved with ``utils.save_results``
    and a progress line is logged.

    Relies on the module-level ``args`` (``SAVE_ROOT`` and ``Dataset``).

    Returns:
        The same ``net`` object.
    """
    # Mode "w" (re)creates an empty training log, exactly as before.  Nothing
    # in this function writes to it, so the handle is released as soon as
    # training ends or fails instead of being left to the garbage collector.
    log_file = open(args.SAVE_ROOT+"/"+args.Dataset+"_training.log","w",1)
    try:
        log_print("Training ....", color='green', attrs=['bold'])
        step_cnt = 0
        t = Timer()
        t.tic()
        for epoch in range(1, num_epochs + 1):
            for step, blob in enumerate(train_loader):
                im_data = blob['data']
                gt_data = blob['gt_density']
                density_map = net(im_data, gt_data)
                loss = net.loss
                step_cnt += 1
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if step % disp_interval == 0:
                    duration = t.toc(average=False)
                    fps = step_cnt / duration
                    gt_count = np.sum(gt_data)
                    density_map = density_map.data.cpu().numpy()
                    et_count = np.sum(density_map)
                    utils.save_results(im_data,gt_data,density_map, args.SAVE_ROOT)
                    log_text = 'epoch: %4d, step %4d, Time: %.4fs, gt_cnt: %4.1f, et_cnt: %4.1f' % (epoch,
                        step, 1./fps, gt_count,et_count)
                    log_print(log_text, color='green', attrs=['bold'])
                    # Start a new measurement window.  The step counter has
                    # to restart together with the timer, otherwise the
                    # logged time divides one window's duration by the total
                    # number of steps since training began.
                    step_cnt = 0
                    t.tic()
    finally:
        log_file.close()
    return net
コード例 #7
0
        train_loss = 0
        outer_timer = Timer()
        outer_timer.tic()
        '''regenerate crop patches'''
        data_loader_train.shuffle_list()

        load_timer = Timer()
        load_time = 0.0
        iter_timer = Timer()
        iter_time = 0.0
        for i, datas in enumerate(\
                    DataLoader(data_loader_train, batch_size=opt.batch_size, \
                                                    shuffle=True, num_workers=4,drop_last=True)):
            step_cnt += 1
            if i != 0:
                load_time += load_timer.toc(average=False)
            iter_timer.tic()
            img_data = datas[0]
            gt_data = datas[1]
            raw_patch = datas[2]
            gt_count = datas[3]
            fnames = [data_loader_train.query_fname(i) for i in datas[4]]
            batch_size = len(fnames)

            step = step + 1

            net.train()
            density_map = net(img_data, gt_data)
            net.backward(loss_scale)

            loss_value = float(net.loss.item())
コード例 #8
0
def test(net,test_path,optimizer, num_epochs, Dataset=args.Dataset):
    """Self-train ``net`` session by session and export density maps.

    The frame range is cut into sessions of 100 frames.  For each session the
    network's own output (computed with ``net.training = False``) is used as
    the target of a training pass over that session's frames.  Afterwards the
    weights are saved and a density map for every blob of the full test
    loader is written as CSV with the prefix ``./densitymaps/<Dataset><start>``.

    Args:
        net: model exposing ``loss`` after a forward call with a target.
        test_path: data location forwarded to both loaders.
        optimizer: optimizer used for the self-training steps.
        num_epochs: number of passes over each session.
        Dataset: dataset name.  ``"fdst"`` selects frames 451..750, any other
            value 1201..2000.  It is also used for the loader arguments and
            output file names (previously ``args.Dataset`` was used there,
            silently ignoring this parameter).

    Returns:
        The same ``net`` object.
    """
    if Dataset == "fdst":
        test_len = 750
        low_limit = 451
        high_limit = 750
    else:
        test_len = 2000
        low_limit = 1201
        high_limit = 2000
    ses_size = 100

    # Session start frames, closed by the last frame of the test range.
    sessions_list = list(range(low_limit, high_limit, ses_size))
    sessions_list.append(test_len)

    for start_frame, end_frame in zip(sessions_list, sessions_list[1:]):
        test_loader = ImageDataLoader_Val_Test(test_path, None, 'test_split', start_frame, end_frame, shuffle=False, gt_downsample=True, pre_load=True, Dataset=Dataset)
        # Mode "w" (re)creates an empty log for every session, as before.
        # Nothing here writes to it, so the handle is closed when the session
        # ends rather than being left open.
        log_file = open(args.SAVE_ROOT + "/" + Dataset + "_test.log", "w", 1)
        try:
            log_print("test/Self Training ....", color='green', attrs=['bold'])
            step_cnt = 0
            t = Timer()
            t.tic()
            for epoch in range(1, num_epochs + 1):
                for step, blob in enumerate(test_loader):
                    im_data = blob['data']
                    # First pass without a target produces the pseudo label.
                    net.training = False
                    gt_data = net(im_data)
                    gt_data = gt_data.cpu().detach().numpy()
                    # Second pass uses it as ground truth.
                    # NOTE(review): this sets only the attribute on ``net``
                    # itself; after the ``net.eval()`` below, later sessions
                    # may run with sub-modules still in eval mode - confirm
                    # that is intended.
                    net.training = True
                    density_map = net(im_data, gt_data)
                    loss = net.loss
                    step_cnt += 1
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    if step % disp_interval == 0:
                        duration = t.toc(average=False)
                        fps = step_cnt / duration
                        gt_count = np.sum(gt_data)
                        density_map = density_map.data.cpu().numpy()
                        et_count = np.sum(density_map)
                        utils.save_results(im_data,gt_data,density_map, args.SAVE_ROOT)
                        log_text = 'epoch: %4d, step %4d, Time: %.4fs, gt_cnt: %4.1f, et_cnt: %4.1f' % (epoch,
                            step, 1./fps, gt_count,et_count)
                        log_print(log_text, color='green', attrs=['bold'])
                        # Restart counter and timer together so the logged
                        # time is the average over this window only.
                        step_cnt = 0
                        t.tic()
        finally:
            log_file.close()

        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = False

        session = str(start_frame)
        network.save_net(args.SAVE_ROOT + '/' + Dataset + session + '_self_trained_model_test.h5', net)
        output_dir = './densitymaps/' + Dataset + session
        net.cuda()
        net.eval()

        all_test_loader = ImageDataLoader(test_path, None, 'test_split', shuffle=False, gt_downsample=True, pre_load=True, Dataset=Dataset)

        for blob in all_test_loader:
            im_data = blob['data']
            net.training = False
            density_map = net(im_data)
            density_map = density_map.data.cpu().numpy()
            # Drop the two leading axes and keep the last two as a 2-D map.
            new_dm = density_map.reshape([density_map.shape[2], density_map.shape[3]])

            np.savetxt(output_dir + '_output_' + blob['fname'].split('.')[0] + '.csv', new_dm, delimiter=',', fmt='%.6f')

    return net
コード例 #9
0
            im_data = np.flip(im_data, 3).copy()
            gt_data = np.flip(gt_data, 3).copy()
        if np.random.uniform() > 0.5:
            #add random noise to the input image
            im_data = im_data + np.random.uniform(-10, 10, size=im_data.shape)

        density_map = net(im_data, gt_data, gt_class_label, class_wts)
        loss = net.loss
        train_loss += loss.data[0]
        step_cnt += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % disp_interval == 0:
            duration = t.toc(average=False)
            fps = step_cnt / duration
            gt_count = np.sum(gt_data)
            density_map = density_map.data.cpu().numpy()
            et_count = np.sum(density_map)
            utils.save_results(im_data, gt_data, density_map, output_dir)
            log_text = 'epoch: %4d, step %4d, Time: %.4fs, gt_cnt: %4.1f, et_cnt: %4.1f' % (
                epoch, step, 1. / fps, gt_count, et_count)
            log_print(log_text, color='green', attrs=['bold'])
            re_cnt = True

        if re_cnt:
            t.tic()
            re_cnt = False

    if (epoch % 2 == 0):
コード例 #10
0
def test_model_patches(net,
                       data_loader,
                       save_output=False,
                       save_path=None,
                       test_fixed_size=-1,
                       test_batch_size=1,
                       gpus=None):
    """Evaluate ``net`` patch-wise and accumulate counting errors.

    Each blob from ``data_loader.get_loader(1)`` holds the patches of one
    image.  They are pushed through ``net`` in chunks of ``test_batch_size``.
    When the module-level ``args.test_fixed_size`` is not -1 the patch
    outputs are stitched back into one map on the grid implied by ``mask``
    and then cropped with ``mask``; otherwise the first chunk's output is
    used as the density map.

    Args:
        net: model; switched to eval mode and called under ``torch.no_grad()``.
        data_loader: object providing ``get_loader(batch_size)`` and ``len()``.
        save_output: when True each blob carries an extra ground-truth
            density entry, and image / ground truth / estimate are written
            through ``dgen`` to ``save_path``.
        save_path: output location handed to the ``dgen`` save helpers.
        test_fixed_size: not read by this function.
            NOTE(review): the body consults the global ``args.test_fixed_size``
            instead - confirm which one callers rely on.
        test_batch_size: number of patches per forward pass.
        gpus: not used by this function.

    Returns:
        Tuple ``(mae, mse, detail)``: summed absolute error divided by
        ``len(data_loader)``, square root of the summed squared error divided
        by ``len(data_loader)``, and a text report with one line per image.
    """
    timer = Timer()
    timer.tic()
    net.eval()
    mae = 0.0
    mse = 0.0
    detail_lines = []
    if save_output:
        print(save_path)
    for i, blob in enumerate(data_loader.get_loader(1)):

        if (i + 1) % 10 == 0:
            print("testing %d" % (i + 1))
        if save_output:
            index, fname, data, mask, gt_dens, gt_count = blob
        else:
            index, fname, data, mask, gt_count = blob

        data = data.squeeze_()
        if len(data.shape) == 3:
            # Image smaller than the crop size: squeezing removed the patch
            # axis as well, so put it back.
            data = data.unsqueeze_(dim=0)
        mask = mask.squeeze_()
        num_patch = len(data)
        # (start, stop) slices covering all patches.  The comprehension
        # variables deliberately avoid the name ``i``: on Python 2 list
        # comprehension variables leak into the enclosing scope and used to
        # overwrite the image index reported in ``detail``.
        batches = [(start, min(start + test_batch_size, num_patch))
                   for start in range(0, num_patch, test_batch_size)]
        with torch.no_grad():
            dens_patch = [net(data[lo:hi]).cpu() for lo, hi in batches]

            if args.test_fixed_size != -1:
                H, W = mask.shape
                _, _, fixed_size = data[0].shape
                assert args.test_fixed_size == fixed_size
                density_map = torch.zeros((H, W))
                # Floor division keeps the grid sizes integral on Python 3
                # as well (``range`` rejects floats).
                for dens_slice, (x, y) in zip(
                        itertools.chain(*dens_patch),
                        itertools.product(range(W // fixed_size),
                                          range(H // fixed_size))):
                    density_map[y * fixed_size:(y + 1) * fixed_size, x *
                                fixed_size:(x + 1) * fixed_size] = dens_slice
                H = mask.sum(dim=0).max().item()
                W = mask.sum(dim=1).max().item()
                density_map = density_map.masked_select(mask).view(H, W)
            else:
                density_map = dens_patch[0]

            gt_count = gt_count.item()
            et_count = density_map.sum().item()

            if save_output:
                # Undo the per-channel normalisation in place.
                image = data.mul_(torch.Tensor([0.229,0.224,0.225]).view(3,1,1))\
                                        .add_(torch.Tensor([0.485,0.456,0.406]).view(3,1,1))

                if args.test_fixed_size != -1:
                    H, W = mask.shape
                    _, _, fixed_size = data[0].shape
                    assert args.test_fixed_size == fixed_size
                    inital_img = torch.zeros((3, H, W))
                    for img_slice, (x, y) in zip(
                            image,
                            itertools.product(range(W // fixed_size),
                                              range(H // fixed_size))):
                        inital_img[:, y * fixed_size:(y + 1) * fixed_size, x *
                                   fixed_size:(x + 1) * fixed_size] = img_slice
                    H = mask.sum(dim=0).max().item()
                    W = mask.sum(dim=1).max().item()
                    inital_img = inital_img.masked_select(mask).view(3, H, W)
                    image = inital_img

                image = image.data.cpu().numpy()
                dgen.save_image(
                    image.transpose((1, 2, 0)) * 255.0, save_path,
                    fname[0].split('.')[0] + "_0_img.png")
                gt_dens = gt_dens.data.cpu().numpy()
                density_map = density_map.data.cpu().numpy()
                dgen.save_density_map(gt_dens.squeeze(), save_path,
                                      fname[0].split('.')[0] + "_1_gt.png")
                dgen.save_density_map(density_map.squeeze(), save_path,
                                      fname[0].split('.')[0] + "_2_et.png")
                del gt_dens
            del data, dens_patch

        detail_lines.append("index: {}; fname: {}; gt: {}; et: {};\n".format(
            i, fname[0].split('.')[0], gt_count, et_count))
        mae += abs(gt_count - et_count)
        mse += ((gt_count - et_count) * (gt_count - et_count))
    mae = mae / len(data_loader)
    mse = np.sqrt(mse / len(data_loader))
    duration = timer.toc(average=False)
    # Single-argument print calls behave the same on Python 2 and 3.
    print("testing time: %d" % duration)
    return mae, mse, ''.join(detail_lines)
コード例 #11
0
ファイル: test_nwpu_loc.py プロジェクト: WangyiNTU/SCALNet
def test_model_origin(net,
                      data_loader,
                      save_output=False,
                      save_path=None,
                      test_fixed_size=-1,
                      test_batch_size=1,
                      gpus=None,
                      args=None):
    """Evaluate a localisation/counting ``net`` and optionally dump results.

    The count estimate is the number of positions in the NMS-filtered output
    that reach ``args.det_thr``; a second estimate is the sum of the
    (clamped at zero) second network output.

    Args:
        net: model returning two outputs for ``net(data)``; switched to eval
            mode.
        data_loader: object providing ``get_loader(batch_size, num_workers=)``,
            ``len()`` and ``dataloader.image_path``.
        save_output: when True each blob carries an extra ground-truth
            density entry, and image, ground truth, estimate and thresholded
            detections are written through ``dgen`` to ``save_path``.
        save_path: output location; with ``args.save_txt`` the text files go
            to the same path with ``density_maps`` replaced by
            ``loc_txt_test``.
        test_fixed_size: not used by this function.
        test_batch_size: batch size passed to the loader and used to reshape
            the outputs.
        gpus: optional sequence; only its length is used, to pace the
            progress message.  ``None`` is treated as a single device.
        args: options object; ``save_txt``, ``det_thr``, ``num_workers`` and
            ``test_patch`` are read, so it must be supplied despite the
            ``None`` default.

    Returns:
        Tuple ``(mae, mse, detail)``: summed absolute error divided by
        ``len(data_loader)``, square root of the summed squared error divided
        by ``len(data_loader)``, and a text report with one line per sample.
    """
    timer = Timer()
    timer.tic()
    net.eval()
    mae = 0.0
    mse = 0.0
    detail_lines = []
    # The default ``gpus=None`` used to crash on ``len(None)``.
    num_gpus = 1 if gpus is None else len(gpus)
    record = None
    record2 = None
    try:
        if args.save_txt:
            save_txt_path = save_path.replace('density_maps', 'loc_txt_test')
            if not os.path.exists(save_txt_path):
                os.mkdir(save_txt_path)
            record = open(
                save_txt_path +
                '/DLA_loc_test_thr_{:.02f}.txt'.format(args.det_thr), 'w+')
            record2 = open(save_txt_path + '/DLA_cnt_test_den.txt', 'w+')

        if save_output:
            print(save_path)
        for i, blob in enumerate(
                data_loader.get_loader(test_batch_size,
                                       num_workers=args.num_workers)):
            if (i * num_gpus + 1) % 100 == 0:
                print("testing %d" % (i + 1))
            if save_output:
                index, fname, data, mask, gt_dens, gt_count = blob
            else:
                index, fname, data, mask, gt_count = blob

            if not args.test_patch:
                with torch.no_grad():
                    dens, dm = net(data)
                    dens = dens.sigmoid_()
                    dens_nms = network._nms(dens.detach())
                    dens_nms = dens_nms.data.cpu().numpy()
                    dm = dm.data.cpu().numpy()
            else:  #TODO
                dens, dens_nms, dm = test_patch(data)
                dens_nms = dens_nms.data.cpu().numpy()
                dm = dm.data.cpu().numpy()

            dm[dm < 0] = 0.0
            gt_count = gt_count.item()
            # Only the first sample of the batch is scored.
            et_count = np.sum(
                dens_nms.reshape(test_batch_size, -1) >= args.det_thr,
                axis=-1)[0]
            et_count_dm = np.sum(dm.reshape(test_batch_size, -1), axis=-1)[0]
            stem = fname[0].split('.')[0]

            if save_output:
                # De-normalise a copy; ``data`` itself is still needed below.
                image = data.clone().squeeze_().mul_(torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)) \
                    .add_(torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)).data.cpu().numpy()

                dgen.save_image(
                    image.transpose((1, 2, 0)) * 255.0, save_path,
                    stem + "_0_img.jpg")
                gt_dens = gt_dens.data.cpu().numpy()
                density_map = dens.data.cpu().numpy()
                dgen.save_density_map(gt_dens.squeeze(), save_path,
                                      stem + "_1_gt.jpg", gt_count)
                dgen.save_density_map(density_map.squeeze(), save_path,
                                      stem + "_2_et.jpg")
                dens_mask = dens_nms >= args.det_thr
                dgen.save_heatmep_pred(dens_mask.squeeze(), save_path,
                                       stem + "_3_et.jpg", et_count)
                del gt_dens

            if args.save_txt:
                ori_img = Image.open(
                    os.path.join(data_loader.dataloader.image_path, fname[0]))
                ori_w, ori_h = ori_img.size
                h, w = data.shape[2], data.shape[3]
                ratio_w = float(ori_w) / w
                ratio_h = float(ori_h) / h
                dens_nms[dens_nms >= args.det_thr] = 1
                dens_nms[dens_nms < args.det_thr] = 0
                ids = np.array(np.where(dens_nms == 1))  # y,x
                # Map detections to the centre of the matching cell in the
                # original image.
                ori_ids_y = ids[2, :] * ratio_h + ratio_h / 2
                ori_ids_x = ids[3, :] * ratio_w + ratio_w / 2
                # int32 rather than int16: coordinates above 32767 would
                # wrap around in the narrower type.
                ids = np.vstack((ori_ids_x, ori_ids_y)).astype(np.int32)  # x,y

                loc_str = ''.join(' {} {}'.format(x, y)
                                  for x, y in zip(ids[0], ids[1]))
                # No trailing newline after the final record.
                line_end = '' if i == len(data_loader) - 1 else '\n'
                record.write('{filename} {pred:d}{loc_str}'.format(
                    filename=stem, pred=et_count, loc_str=loc_str) + line_end)
                record2.write('{filename} {pred:0.2f}'.format(
                    filename=stem, pred=float(et_count_dm)) + line_end)

            del data, dens
            detail_lines.append(
                "index: {}; fname: {}; gt: {}; et: {}; dif: {};\n".format(
                    i, stem, gt_count, et_count, gt_count - et_count))
            mae += abs(gt_count - et_count)
            mse += ((gt_count - et_count) * (gt_count - et_count))
    finally:
        # Both text files are closed, also when evaluation fails part-way
        # (previously only ``record`` was closed, and only on success).
        for handle in (record, record2):
            if handle is not None:
                handle.close()
    mae = mae / len(data_loader)
    mse = np.sqrt(mse / len(data_loader))
    duration = timer.toc(average=False)
    print("testing time: %d" % duration)
    return mae, mse, ''.join(detail_lines)
コード例 #12
0
def main():
    """Train ``CrowdCounter`` on the hard-coded SH_B paths with checkpointing.

    After every epoch a checkpoint is written to ``./checkpoint/``, MAE/MSE
    on the training and validation loaders are appended to CSV files in
    ``./mae_mse/``, and the network is saved to ``./saved_models/`` whenever
    the validation MAE improves (ties are broken by a lower MSE).
    """
    # define output folder
    output_dir = './saved_models/'
    log_dir = './mae_mse/'
    checkpoint_dir = './checkpoint/'

    train_path = '/home/jake/Desktop/Projects/Python/dataset/SH_B/cooked/train/images'
    train_gt_path = '/home/jake/Desktop/Projects/Python/dataset/SH_B/cooked/train/ground_truth'
    val_path = '/home/jake/Desktop/Projects/Python/dataset/SH_B/cooked/val/images'
    val_gt_path = '/home/jake/Desktop/Projects/Python/dataset/SH_B/cooked/val/ground_truth'

    # All three folders are written to below; previously only ``output_dir``
    # was created, so a fresh checkout failed when opening the CSV logs.
    for directory in (output_dir, log_dir, checkpoint_dir):
        if not os.path.exists(directory):
            os.mkdir(directory)

    # last checkpoint
    checkpointfile = os.path.join(checkpoint_dir, 'checkpoint.94.pth.tar')

    # some description
    method = 'mcnn'
    dataset_name = 'SH_B'

    # Training configuration
    start_epoch = 0
    end_epoch = 97
    lr = 0.00001
    disp_interval = 1000

    # Flag
    CONTINUE_TRAIN = True

    rand_seed = 64678
    if rand_seed is not None:
        np.random.seed(rand_seed)
        torch.manual_seed(rand_seed)
        torch.cuda.manual_seed(rand_seed)

    # Define network
    net = CrowdCounter()
    network.weights_normal_init(net, dev=0.01)
    # net.cuda()
    net.train()
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, net.parameters()), lr=lr)

    if CONTINUE_TRAIN:
        net, optimizer, start_epoch = utils.load_checkpoint(
            net, optimizer, filename=checkpointfile)

    step_cnt = 0
    t = Timer()
    t.tic()

    # Load data
    data_loader = ImageDataLoader(
        train_path, train_gt_path, shuffle=True, gt_downsample=True, pre_load=True)
    data_loader_val = ImageDataLoader(
        val_path, val_gt_path, shuffle=False, gt_downsample=True, pre_load=True)
    best_mae = sys.maxsize
    # Defined up front so the tie-break comparison can never hit an
    # unassigned name.
    best_mse = sys.maxsize

    # The CSV logs are closed even when training stops with an exception.
    with open(os.path.join(log_dir, "train_loss.csv"), "a+") as f_train_loss, \
            open(os.path.join(log_dir, "val_loss.csv"), "a+") as f_val_loss:
        # NOTE(review): the upper bound ``end_epoch-1`` is kept as found;
        # confirm that stopping before ``end_epoch - 1`` is intended.
        for this_epoch in range(start_epoch, end_epoch-1):
            for step, blob in enumerate(data_loader):
                img_data = blob['data']
                gt_data = blob['gt_density']
                et_data = net(img_data, gt_data)
                loss = net.loss
                step_cnt += 1
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if step % disp_interval == 0:
                    duration = t.toc(average=False)
                    fps = step_cnt / duration
                    gt_count = np.sum(gt_data)
                    et_data = et_data.data.cpu().numpy()
                    et_count = np.sum(et_data)
                    utils.save_results(img_data, gt_data, et_data, output_dir,
                                       fname="{}.{}.png".format(this_epoch, step))
                    log_text = 'epoch: %4d, step %4d, Time: %.4fs, gt_cnt: %4.1f, et_cnt: %4.1f' % (this_epoch,
                                                                                                    step, 1./fps, gt_count, et_count)
                    log_print(log_text, color='green', attrs=['bold'])
                    # Restart counter and timer together so the logged time
                    # is the per-step average of this window only.
                    step_cnt = 0
                    t.tic()

            # Save checkpoint
            state = {'epoch': this_epoch, 'state_dict': net.state_dict(),
                     'optimizer': optimizer.state_dict()}
            cp_filename = "checkpoint.{}.pth.tar".format(this_epoch)
            torch.save(state, os.path.join(checkpoint_dir, cp_filename))

            # Error on the training data after this epoch
            train_mae, train_mse = evaluate_network(net, data_loader)
            f_train_loss.write("{},{}\n".format(train_mae, train_mse))
            log_text = 'TRAINING - EPOCH: %d, MAE: %.1f, MSE: %0.1f' % (
                this_epoch, train_mae, train_mse)
            log_print(log_text, color='green', attrs=['bold'])

            # Error on the validation data
            val_mae, val_mse = evaluate_network(net, data_loader_val)
            f_val_loss.write("{},{}\n".format(val_mae, val_mse))
            log_text = 'VALIDATION - EPOCH: %d, MAE: %.1f, MSE: %0.1f' % (
                this_epoch, val_mae, val_mse)
            log_print(log_text, color='green', attrs=['bold'])

            # Keep the model when MAE improves, or on equal MAE when MSE does.
            is_save = val_mae < best_mae or (
                val_mae == best_mae and val_mse < best_mse)
            if is_save:
                best_mae = val_mae
                best_mse = val_mse
                best_model = '{}_{}_{}.h5'.format(
                    method, dataset_name, this_epoch)
                save_name = os.path.join(output_dir, best_model)
                network.save_net(save_name, net)
                log_text = 'BEST MAE: %0.1f, BEST MSE: %0.1f, BEST MODEL: %s' % (
                    best_mae, best_mse, best_model)
                log_print(log_text, color='green', attrs=['bold'])