Beispiel #1
0
 def get_item(self, index, shift=None):
     """Load a window of consecutive frames plus per-frame multi-hot labels.

     Args:
         index: position of the video record inside ``self.data['datas']``.
         shift: when ``None`` the first frame of the window is drawn with
             ``np.random.randint``; otherwise it is scaled by
             ``n - self.train_gap - 2`` and truncated to an int
             (presumably a fraction in [0, 1) -- confirm with callers).

     Returns:
         ``(img, target, meta)``: the stacked frames permuted with
         ``(0, 2, 3, 1)`` and converted to a numpy array (then passed
         through ``self.transform`` if set), the stacked per-frame
         targets (passed through ``self.target_transform`` if set), and
         a dict with the keys ``do_not_collate``, ``id`` and ``time``.

     Raises:
         Exception: whatever the image loader raises is re-raised after
             the failing path has been printed.
     """
     record = self.data['datas'][index]
     fps = 24
     meta = {}
     meta['do_not_collate'] = True

     n = record['n']
     span = n - self.train_gap - 2
     if shift is None:
         shift = np.random.randint(span)
     else:
         shift = int(shift * span)

     resize = transforms.Resize(int(256. / 224 * self.input_size))
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
     to_tensor = transforms.ToTensor()

     frames, frame_targets = [], []
     for loc in np.arange(shift, shift + self.train_gap):
         ii = int(np.floor(loc))
         # Frame files are numbered from 1, hence the ``ii + 1``.
         path = '{}{:06d}.jpg'.format(record['base'], ii + 1)
         try:
             # Temporary instrumentation: time the disk read of one frame.
             timer = Timer()
             img = default_loader(path)
             load_img_cost = timer.thetime() - timer.end
             timer.tic()
             print(
                 'Load image from disk: {0:.3f} sec'.format(load_img_cost))
         except Exception as e:
             print('failed to load image {}'.format(path))
             print(e)
             raise
         frames.append(normalize(to_tensor(resize(img))))

         # A class is active on this frame when the frame's timestamp lies
         # strictly inside the label's (start, end) interval.
         seconds = ii / float(fps)
         target = torch.IntTensor(self.num_classes).zero_()
         for label in record['labels']:
             if label['start'] < seconds < label['end']:
                 target[self.cls2int(label['class'])] = 1
         frame_targets.append(target)

     meta['id'] = record['id']
     meta['time'] = shift

     img = torch.stack(frames).permute(0, 2, 3, 1).numpy()
     target = torch.stack(frame_targets)
     if self.transform is not None:
         img = self.transform(img)
     if self.target_transform is not None:
         target = self.target_transform(target)
     # batch will be b x n x h x w x c
     # target will be b x n x nc
     return img, target, meta
Beispiel #2
0
def train(loader, model, optimizer, epoch, args):
    """Run one training epoch with cross-entropy loss and gradient accumulation.

    Args:
        loader: iterable yielding ``(input, meta)`` batches; ``len(loader)``
            is used when logging progress.
        model: network called as ``model(input)`` and returning a
            ``(metric_feat, output)`` pair; only ``output`` is used here.
        optimizer: optimizer stepped every ``args.accum_grad`` iterations.
        epoch: current epoch number, used for the learning-rate schedule
            and in log lines.
        args: namespace providing ``lr_decay_rate``, ``debug``,
            ``accum_grad``, ``print_freq`` and ``best_score``.
    """
    timer = Timer()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    ce_loss_meter = AverageMeter()
    cur_lr = adjust_learning_rate(args.lr_decay_rate, optimizer, epoch)
    model.train()
    optimizer.zero_grad()
    ce_loss_criterion = nn.CrossEntropyLoss()
    for i, (inputs, meta) in tqdm(enumerate(loader), desc="Train Epoch"):
        if args.debug and i >= debug_short_train_num:
            break
        # Time spent waiting for this batch since the timer was last ticked.
        data_time.update(timer.thetime() - timer.end)

        _batch_size = len(meta)
        target = []
        for sample_idx in range(_batch_size):
            target.extend(meta[sample_idx]["labels"])
        target = torch.from_numpy(np.array(target))
        # Fold the second dimension into the batch dimension; the factor 3
        # presumably is the number of clips per sample -- TODO confirm.
        inputs = inputs.view(
            _batch_size * 3,
            inputs.shape[2],
            inputs.shape[3],
            inputs.shape[4],
            inputs.shape[5],
        )
        _, output = model(inputs)
        ce_loss = ce_loss_criterion(output.cuda(), target.long().cuda())
        loss = ce_loss

        loss.backward()
        loss_meter.update(loss.item())
        ce_loss_meter.update(ce_loss.item())
        # Gradients accumulate over ``accum_grad`` iterations before a step.
        if i % args.accum_grad == args.accum_grad - 1:
            optimizer.step()
            optimizer.zero_grad()

        if i % args.print_freq == 0 and i > 0:
            logger.info("[{0}][{1}/{2}]\t"
                        "Dataload_Time={data_time.avg:.3f}\t"
                        "Loss={loss.avg:.4f}\t"
                        "CELoss={ce_loss.avg:.4f}\t"
                        "LR={cur_lr:.7f}\t"
                        "bestAP={ap:.3f}".format(
                            epoch,
                            i,
                            len(loader),
                            data_time=data_time,
                            loss=loss_meter,
                            ce_loss=ce_loss_meter,
                            ap=args.best_score,
                            cur_lr=cur_lr,
                        ))
            loss_meter.reset()
            ce_loss_meter.reset()

        # Restart the timer at the end of every iteration so the next
        # data_time sample covers only the wait for the next batch rather
        # than the time elapsed since the timer was created.
        # NOTE(review): assumes Timer.tic() resets ``timer.end`` -- confirm.
        timer.tic()
def read_multi_images(target_dir):
    """Read every regular file directly inside *target_dir* and time the pass.

    Directory entries that are not regular files are skipped.

    Returns:
        A ``(time_cost, image_num)`` pair: the value of
        ``timer.thetime() - timer.end`` after all reads, and the number of
        files that were read.
    """
    timer = Timer()
    image_num = 0
    for entry in listdir(target_dir):
        candidate = join(target_dir, entry)
        if not isfile(candidate):
            continue
        read_one_iamge(candidate)
        image_num += 1

    time_cost = timer.thetime() - timer.end
    message = 'Load images from disk: {0:.3f} sec'.format(time_cost)
    print(message)
    return time_cost, image_num
Beispiel #4
0
    def train(self, loader, model, criterion, optimizer, epoch, metrics, args, validate=False):
        """Run one epoch of training, or of validation when *validate* is set.

        Args:
            loader: batch source; iterated through ``part(loader, iter_size)``
                and yielding ``(input, target, meta)`` triples.
            model: called as ``model(input, meta)``; a non-tuple result is
                wrapped into a 1-tuple.
            criterion: called with the model outputs followed by
                ``target`` and ``meta``; returns
                ``(scores, loss, score_target)``.
            optimizer: stepped every ``args.accum_grad`` iterations when
                training; untouched when validating.
            epoch: epoch number, used for the LR schedule and log lines.
            metrics: iterable of zero-argument callables; each is called
                once to build a metric object exposing ``update`` and
                ``compute``.
            args: namespace providing ``val_size``/``train_size``, ``lr``,
                ``lr_decay_rate``, ``synchronous``, ``cpu``, ``accum_grad``,
                ``print_freq`` and ``name``.
            validate: selects evaluation mode and the ``val_`` key prefix.

        Returns:
            dict mapping ``'train_<name>'`` or ``'val_<name>'`` to each
            metric's computed value, plus the average loss under ``loss``.
        """
        timer = Timer()
        data_time = AverageMeter()
        losses = AverageMeter()
        metrics = [factory() for factory in metrics]

        if not validate:
            # switch to train mode
            adjust_learning_rate(args.lr, args.lr_decay_rate, optimizer, epoch)
            model.train()
            criterion.train()
            optimizer.zero_grad()
            iter_size = args.train_size
            setting = 'Train Epoch'
        else:
            # switch to evaluate mode
            model.eval()
            criterion.eval()
            iter_size = args.val_size
            setting = 'Validate Epoch'

        for step, (inputs, target, meta) in enumerate(part(loader, iter_size)):
            if args.synchronous:
                assert meta['id'][0] == meta['id'][1], "dataset not synced"
            data_time.update(timer.thetime() - timer.end)

            if not args.cpu:
                target = target.cuda(non_blocking=True)
            output = model(inputs, meta)
            if type(output) is not tuple:
                output = (output,)
            scores, loss, score_target = criterion(*(output + (target, meta)))
            losses.update(loss.item())
            with torch.no_grad():
                for metric in metrics:
                    metric.update(scores, score_target)

            if not validate:
                loss.backward()
                # Gradients accumulate; step only on the last iteration of
                # each group of ``accum_grad`` iterations.
                if step % args.accum_grad == args.accum_grad-1:
                    print('updating parameters')
                    optimizer.step()
                    optimizer.zero_grad()

            timer.tic()
            if step % args.print_freq == 0:
                metric_text = ' \t'.join(str(metric) for metric in metrics)
                print('[{name}] {setting}: [{0}][{1}/{2}({3})]\t'
                      'Time {timer.val:.3f} ({timer.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      '{metrics}'.format(
                          epoch, step, int(len(loader)*iter_size), len(loader),
                          name=args.name, setting=setting, timer=timer,
                          data_time=data_time, loss=losses,
                          metrics=metric_text))
            del loss, output, target  # make sure we don't hold on to the graph

        results = dict(metric.compute() for metric in metrics)
        results['loss'] = losses.avg
        prefix = 'val_' if validate else 'train_'
        return {prefix + key: value for key, value in results.items()}
""" Video loader for the Charades dataset """

from datasets import utils
from misc_utils.utils import Timer

# Alternative sample frame on another machine:
# path = '/home/SERILOCAL/xiatian.zhu/Data/test_video/ZZXQF-000002.jpg'
path = '/home/nfs/x.chang/Datasets/Charades/Charades/Charades_v1_rgb/ZZXQF/ZZXQF-000002.jpg'

# Number of times the same frame is loaded and timed.
NUM_TRIALS = 10

for trial in range(NUM_TRIALS):
    try:
        # Temporary instrumentation: a fresh timer per read, so each
        # printed value covers a single load of the file.
        timer = Timer()
        img = utils.default_loader(path)
        load_img_cost = timer.thetime() - timer.end
        timer.tic()
        report = 'Load image from disk: {0:.3f} sec'.format(load_img_cost)
        print(report)
    except Exception as e:
        # Report which file failed, then let the error propagate.
        print('failed to load image {}'.format(path))
        print(e)
        raise