Example #1
def train_linear_one_epoch(train_loader, model, criterion, optimizer, config,
                           device, epoch):
    log_header = 'EPOCH {}'.format(epoch + 1)
    losses = AverageMeter('Loss', fmt=':.4f')
    top1 = AverageMeter('Top1', fmt=':4.2f')
    top5 = AverageMeter('Top5', fmt=':4.2f')
    lr = AverageMeter('Lr', fmt=":.4f")

    metric_logger = MetricLogger(delimeter=" | ")
    metric_logger.add_meter(losses)
    metric_logger.add_meter(top1)
    metric_logger.add_meter(top5)
    metric_logger.add_meter(lr)

    for step, (img, target) in enumerate(
            metric_logger.log_every(train_loader, config.system.print_freq,
                                    log_header)):
        img = img.to(device)
        target = target.to(device)
        logit = model(img)

        loss = criterion(logit, target)
        acc1, acc5 = accuracy(logit, target, topk=(1, 5))
        lr_ = optimizer.param_groups[0]['lr']

        metric_logger.update(Loss=loss.detach().cpu().item(),
                             Top1=acc1.detach().cpu().item(),
                             Top5=acc5.detach().cpu().item(),
                             Lr=lr_)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
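
Examples #1 and #6 assume a project-specific AverageMeter/MetricLogger pair that is not shown here (note the delimeter spelling and add_meter() taking a meter object rather than a name). A minimal sketch compatible with the way those two examples call it, assuming log_every only needs to yield batches and periodically print the registered meters, could look like this:

from collections import OrderedDict


class AverageMeter:
    """Tracks the latest value and running average of a scalar metric."""

    def __init__(self, name, fmt=':.4f'):
        self.name, self.fmt = name, fmt
        self.val = self.sum = self.count = 0

    def update(self, value, n=1):
        self.val = value
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

    def __str__(self):
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(name=self.name, val=self.val, avg=self.avg)


class MetricLogger:
    """Groups AverageMeters and prints them every print_freq steps."""

    def __init__(self, delimeter=" | "):  # spelling kept to match the calls above
        self.delimeter = delimeter
        self.meters = OrderedDict()

    def add_meter(self, meter):
        self.meters[meter.name] = meter

    def update(self, **kwargs):
        for name, value in kwargs.items():
            self.meters[name].update(value)

    def log_every(self, iterable, print_freq, header=''):
        for step, obj in enumerate(iterable):
            yield obj
            if step % print_freq == 0:
                stats = self.delimeter.join(str(m) for m in self.meters.values())
                print('{} [{}/{}] {}'.format(header, step, len(iterable), stats))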
Example #2
def evaluate(model, criterion, data_loader, device):
    model.eval()
    metric_logger = MetricLogger(delimiter="  ")
    header = 'Test:'
    with torch.no_grad():
        for video, target in metric_logger.log_every(data_loader, 100, header):
            start_time = time.time()
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(video)
            time_diff = time.time() - start_time
            print("Predicting on a video of shape {} took {} seconds".format(
                video.shape, time_diff))
            print("target shape {}".format(target.shape))
            print("target {}".format(target))
            loss = criterion(output, target)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = video.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    print(
        ' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}'.
        format(top1=metric_logger.acc1, top5=metric_logger.acc5))
    return metric_logger.acc1.global_avg
Example #3
def train_one_epoch(model,
                    optimizer,
                    lr_scheduler,
                    data_loader,
                    epoch,
                    print_freq,
                    checkpoint_fn=None):
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('batch/s',
                            SmoothedValue(window_size=10, fmt='{value:.3f}'))

    header = 'Epoch: [{}]'.format(epoch)

    for step, batched_inputs in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        start_time = time.time()
        loss = model(batched_inputs)

        if checkpoint_fn is not None and np.random.random() < 0.005:
            checkpoint_fn()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss.item(),
                             lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['batch/s'].update((time.time() - start_time))
        lr_scheduler.step()

    if checkpoint_fn is not None:
        checkpoint_fn()
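
Examples #2-#4 and most of the later ones use the MetricLogger/SmoothedValue interface from the torchvision reference scripts, or close variants of it: update(**kwargs) feeds scalars into named SmoothedValue meters, meters are reachable both as attributes (metric_logger.acc1) and by key (metric_logger['loss']), and log_every wraps an iterable and prints periodically. A trimmed, single-process sketch (the real implementations also synchronize meters across distributed workers) might be:

import datetime
import time
from collections import defaultdict, deque

import torch


class SmoothedValue:
    """Tracks a window of recent values plus global statistics."""

    def __init__(self, window_size=20, fmt='{median:.4f} ({global_avg:.4f})'):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        return self.total / max(self.count, 1)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(median=self.median, avg=self.avg,
                               global_avg=self.global_avg, value=self.value)


class MetricLogger:
    def __init__(self, delimiter='\t'):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            self.meters[k].update(float(v))

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def __getattr__(self, name):
        if name in self.meters:
            return self.meters[name]
        raise AttributeError(name)

    def __getitem__(self, name):
        return self.meters[name]

    def __str__(self):
        return self.delimiter.join(
            '{}: {}'.format(name, meter) for name, meter in self.meters.items())

    def synchronize_between_processes(self):
        pass  # single-process sketch: nothing to synchronize

    def log_every(self, iterable, print_freq, header=''):
        start = time.time()
        for i, obj in enumerate(iterable):
            yield obj
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta = (time.time() - start) / (i + 1) * (len(iterable) - i - 1)
                print('{} [{}/{}] eta: {} {}'.format(
                    header, i, len(iterable),
                    datetime.timedelta(seconds=int(eta)), str(self)))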
Example #4
def evaluate(model, epoch, criterion, data_loader, device, writer):
    model.eval()
    metric_logger = MetricLogger(delimiter="  ")
    header = 'Test:'
    cntr = 0
    running_accuracy = 0.0
    with torch.no_grad():
        for video, target in metric_logger.log_every(data_loader, 100, header):
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(video)
            loss = criterion(output, target)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = video.shape[0]
            running_accuracy += acc1.item()
            if cntr % 10 == 9:  # log accuracy averaged over the last 10 mini-batches
                writer.add_scalar('validation accuracy',
                                  running_accuracy / 10,
                                  epoch * len(data_loader) + cntr)
                running_accuracy = 0.0
            cntr += 1
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    print(' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5))
    return metric_logger.acc1.global_avg
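
Several of the loops above call an accuracy(output, target, topk=(1, 5)) helper that is not shown. A reasonable assumption is the standard top-k accuracy helper from the PyTorch ImageNet example, which returns percentages as 0-dim tensors (hence the .item() calls above):

import torch


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for the given logits and labels."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # indices of the maxk largest logits per sample, transposed to (maxk, batch_size)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res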
Example #5
    def train_eval_mse(self, model, valset, writer, global_step, device):
        """
        Evaluate MSE during training
        """
        num_samples = eval_cfg.train.num_samples.mse
        batch_size = eval_cfg.train.batch_size
        num_workers = eval_cfg.train.num_workers

        model.eval()

        if valset is None:
            data_set = np.load('../data/TABLE/val/all_set_val.npy')
            data_size = len(data_set)
            # create indices and shuffle them so the data can be accessed in random order
            idx_set = np.arange(data_size)
            np.random.shuffle(idx_set)
            idx_set = idx_set[:num_samples]
            idx_set = np.split(idx_set, len(idx_set) / batch_size)
            data_to_enumerate = idx_set
        else:
            valset = Subset(valset, indices=range(num_samples))
            dataloader = DataLoader(valset,
                                    batch_size=batch_size,
                                    num_workers=num_workers,
                                    shuffle=False)
            data_to_enumerate = dataloader

        metric_logger = MetricLogger()

        print(f'Evaluating MSE using {num_samples} samples.')
        with tqdm(total=num_samples) as pbar:
            for batch_idx, sample in enumerate(data_to_enumerate):
                if valset is None:
                    data_i = data_set[sample]
                    data_i = torch.from_numpy(data_i).float().to(device)
                    data_i /= 255
                    data_i = data_i.permute([0, 3, 1, 2])
                    imgs = data_i
                else:
                    imgs = sample.to(device)
                loss, log = model(imgs, global_step)
                B = imgs.size(0)
                for b in range(B):
                    metric_logger.update(mse=log['mse'][b])
                metric_logger.update(loss=loss.mean())
                pbar.update(B)

        assert metric_logger['mse'].count == num_samples
        # Add last log
        # log.update([(k, torch.tensor(v.global_avg)) for k, v in metric_logger.values.items()])
        mse = metric_logger['mse'].global_avg
        writer.add_scalar(f'val/mse', mse, global_step=global_step)

        model.train()

        return mse
Example #6
def train_one_epoch(train_loader, model, criterion, optimizer, writer, epoch,
                    total_step, config):
    log_header = 'EPOCH {}'.format(epoch)
    losses = AverageMeter('Loss', fmt=':.4f')
    if config.method != 'byol':
        top1 = AverageMeter('Acc1', fmt=':4.2f')
        top5 = AverageMeter('Acc5', fmt=':4.2f')
    lr = AverageMeter('Lr', fmt=":.6f")

    metric_logger = MetricLogger(delimeter=" | ")
    metric_logger.add_meter(losses)
    if config.method != 'byol':
        metric_logger.add_meter(top1)
        metric_logger.add_meter(top5)
    metric_logger.add_meter(lr)
    # ce = nn.CrossEntropyLoss().cuda(config.system.gpu)
    # num_steps_per_epoch = int(len(train_loader.dataset) // config.train.batch_size)
    # global_step = num_steps_per_epoch * epoch
    for step, (images, _) in enumerate(
            metric_logger.log_every(train_loader, config.system.print_freq,
                                    log_header)):
        total_step.val += 1
        if config.system.gpu is not None:
            images[0] = images[0].cuda(config.system.gpu, non_blocking=True)
            images[1] = images[1].cuda(config.system.gpu, non_blocking=True)

        # [pos, neg]
        # output = model(view_1=images[0], view_2=images[1])
        # loss, logits, targets = criterion(output)
        if config.method != 'byol':
            logits, targets, logits_original = model(view_1=images[0],
                                                     view_2=images[1])
            loss = criterion(logits, targets)
            acc1, acc5 = accuracy(logits_original, targets, topk=(1, 5))
        else:
            loss_pre = model(view_1=images[0], view_2=images[1])
            loss = loss_pre.mean()

        lr_ = optimizer.param_groups[0]['lr']

        if config.method != 'byol':
            metric_logger.update(Loss=loss.detach().cpu().item(),
                                 Acc1=acc1.detach().cpu().item(),
                                 Acc5=acc5.detach().cpu().item(),
                                 Lr=lr_)
        else:
            metric_logger.update(Loss=loss.detach().cpu().item(), Lr=lr_)

        writer.add_scalar('loss', loss.detach().cpu().item(), total_step.val)
        if config.method != 'byol':
            writer.add_scalar('top1',
                              acc1.detach().cpu().item(), total_step.val)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Example #7
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr',
                            SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters,
                                           warmup_factor)

    for images, targets in metric_logger.log_every(data_loader, print_freq,
                                                   header):
        images = list(image.to(device) for image in images)
        #images = list(np.array(img) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger
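
Example #7 relies on a warmup_lr_scheduler helper. In the torchvision detection references this is a thin LambdaLR that ramps the learning-rate multiplier linearly from warmup_factor up to 1 over warmup_iters steps; a sketch along those lines:

import torch


def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
    """Linearly ramp the LR multiplier from warmup_factor to 1 over warmup_iters steps."""

    def f(x):
        if x >= warmup_iters:
            return 1.0
        alpha = float(x) / warmup_iters
        return warmup_factor * (1 - alpha) + alpha

    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)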
Example #8
def evaluate(model, data_loader, device):
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = MetricLogger(delimiter="  ")
    header = 'Test:'

    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(coco, iou_types)

    for images, targets in metric_logger.log_every(data_loader, 100, header):
        images = list(img.to(device) for img in images)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(images)

        outputs = [{k: v.to(cpu_device)
                    for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        res = {
            target["image_id"].item(): output
            for target, output in zip(targets, outputs)
        }
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time,
                             evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator
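
The _get_iou_types helper used above picks which COCO metrics to compute from the model type; in the torchvision detection references it looks roughly like this (shown here as an assumed sketch):

import torch
import torchvision


def _get_iou_types(model):
    """Select COCO IoU types ('bbox', plus 'segm'/'keypoints') based on the model's heads."""
    model_without_ddp = model
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_without_ddp = model.module
    iou_types = ["bbox"]
    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        iou_types.append("keypoints")
    return iou_types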
Example #9
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                    device, epoch, print_freq, writer):
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('clips/s',
                            SmoothedValue(window_size=10, fmt='{value:.3f}'))
    running_loss = 0.0
    running_accuracy = 0.0
    header = 'Epoch: [{}]'.format(epoch)
    cntr = 0
    for video, target in metric_logger.log_every(data_loader, print_freq,
                                                 header):
        start_time = time.time()
        video, target = video.to(device), target.to(device)
        output = model(video)
        loss = criterion(output, target)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = video.shape[0]
        running_loss += loss.item()
        running_accuracy += acc1.item()
        if cntr % 10 == 9:  # log stats averaged over the last 10 mini-batches
            writer.add_scalar('training loss', running_loss / 10,
                              epoch * len(data_loader) + cntr)
            writer.add_scalar('learning rate', optimizer.param_groups[0]["lr"],
                              epoch * len(data_loader) + cntr)
            writer.add_scalar('accuracy', running_accuracy / 10,
                              epoch * len(data_loader) + cntr)
            running_loss = 0.0
            running_accuracy = 0.0
        cntr = cntr + 1
        metric_logger.update(loss=loss.item(),
                             lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        metric_logger.meters['clips/s'].update(batch_size /
                                               (time.time() - start_time))
        lr_scheduler.step()
Example #10
    def train_eval_mse(self, model, valset, writer, global_step, device):
        """
        Evaluate MSE during training
        """
        num_samples = eval_cfg.train.num_samples.mse
        batch_size = eval_cfg.train.batch_size
        num_workers = eval_cfg.train.num_workers
        
        model.eval()
        # valset = Subset(valset, indices=range(num_samples))
        dataloader = DataLoader(valset, batch_size=batch_size, shuffle=True, num_workers=0,
                                drop_last=True, collate_fn=valset.collate_fn)
        #dataloader = DataLoader(valset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    
        metric_logger = MetricLogger()
    
        print(f'Evaluating MSE using {num_samples} samples.')
        n_batch = 0
        with tqdm(total=num_samples) as pbar:
            for batch_idx, sample in enumerate(dataloader):
                imgs = sample[0].to(device)
                loss, log = model(imgs, global_step)
                B = imgs.size(0)
                for b in range(B):
                    metric_logger.update(
                        mse=log['mse'][b],
                    )
                metric_logger.update(loss=loss.mean())
                pbar.update(1)
                n_batch += 1
                if n_batch >= num_samples: break

        # assert metric_logger['mse'].count == num_samples
        # Add last log
        # log.update([(k, torch.tensor(v.global_avg)) for k, v in metric_logger.values.items()])
        mse = metric_logger['mse'].global_avg
        writer.add_scalar(f'val/mse', mse, global_step=global_step)
    
        model.train()
        
        return mse
Example #11
def main():

    resume = True
    path = 'data/NYU_DEPTH'
    batch_size = 16
    epochs = 10000
    device = torch.device('cuda:0')
    print_every = 5
    # exp_name = 'resnet18_nodropout_new'
    exp_name = 'only_depth'
    # exp_name = 'normal_internel'
    # exp_name = 'sep'
    lr = 1e-5
    weight_decay = 0.0005
    log_dir = os.path.join('logs', exp_name)
    model_dir = os.path.join('checkpoints', exp_name)
    val_every = 16
    save_every = 16


    # tensorboard
    # remove old logs if not resuming
    if not resume:
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)
            os.makedirs(log_dir)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    tb = SummaryWriter(log_dir)
    tb.add_custom_scalars({
        'metrics': {
            'thres_1.25': ['Multiline', ['thres_1.25/train', 'thres_1.25/test']],
            'thres_1.25_2': ['Multiline', ['thres_1.25_2/train', 'thres_1.25_2/test']],
            'thres_1.25_3': ['Multiline', ['thres_1.25_3/train', 'thres_1.25_3/test']],
            'ard': ['Multiline', ['ard/train', 'ard/test']],
            'srd': ['Multiline', ['srd/train', 'srd/test']],
            'rmse_linear': ['Multiline', ['rmse_linear/train', 'rmse_linear/test']],
            'rmse_log': ['Multiline', ['rmse_log/train', 'rmse_log/test']],
            'rmse_log_invariant': ['Multiline', ['rmse_log_invariant/train', 'rmse_log_invariant/test']],
        }
    })
    
    
    # data loader
    dataset = NYUDepth(path, 'train')
    dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=4)
    
    dataset_test = NYUDepth(path, 'test')
    dataloader_test = DataLoader(dataset_test, batch_size, shuffle=True, num_workers=4)
    
    
    # load model
    model = FCRN(True)
    model = model.to(device)
    
    
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    start_epoch = 0
    if resume:
        model_path = os.path.join(model_dir, 'model.pth')
        if os.path.exists(model_path):
            print('Loading checkpoint from {}...'.format(model_path))
            # load model and optimizer
            checkpoint = torch.load(os.path.join(model_dir, 'model.pth'), map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            print('Model loaded.')
        else:
            print('No checkpoint found. Train from scratch')
    
    # training
    metric_logger = MetricLogger()
    
    end = time.perf_counter()
    max_iters = epochs * len(dataloader)
    
    def normal_loss(pred, normal, conf):
        """
        :param pred: (B, 3, H, W)
        :param normal: (B, 3, H, W)
        :param conf: (B, 1, H, W)
        """
        dot_prod = (pred * normal).sum(dim=1)
        # weighted loss, (B, )
        batch_loss = ((1 - dot_prod) * conf[:, 0]).sum(1).sum(1)
        # normalize, to (B, )
        batch_loss /= conf[:, 0].sum(1).sum(1)
        return batch_loss.mean()

    def consistency_loss(pred, cloud, normal, conf):
        """
        :param pred: (B, 1, H, W)
        :param normal: (B, 3, H, W)
        :param cloud: (B, 3, H, W)
        :param conf: (B, 1, H, W)
        """
        B, _, _, _ = normal.size()
        normal = normal.detach()
        cloud = cloud.clone()
        cloud[:, 2:3, :, :] = pred
        # algorithm: use a kernel
        kernel = torch.ones((1, 1, 7, 7), device=pred.device)
        kernel = -kernel
        kernel[0, 0, 3, 3] = 48
    
        cloud_0 = cloud[:, 0:1]
        cloud_1 = cloud[:, 1:2]
        cloud_2 = cloud[:, 2:3]
        diff_0 = F.conv2d(cloud_0, kernel, padding=6, dilation=2)
        diff_1 = F.conv2d(cloud_1, kernel, padding=6, dilation=2)
        diff_2 = F.conv2d(cloud_2, kernel, padding=6, dilation=2)
        # (B, 3, H, W)
        diff = torch.cat((diff_0, diff_1, diff_2), dim=1)
        # normalize
        diff = F.normalize(diff, dim=1)
        # (B, 1, H, W)
        dot_prod = (diff * normal).sum(dim=1, keepdim=True)
        # weighted mean over image
        dot_prod = torch.abs(dot_prod.view(B, -1))
        conf = conf.view(B, -1)
        loss = (dot_prod * conf).sum(1) / conf.sum(1)
        # mean over batch
        return loss.mean()
    
    def criterion(depth_pred, normal_pred, depth, normal, cloud, conf):
        mse_loss = F.mse_loss(depth_pred, depth)
        consis_loss = consistency_loss(depth_pred, cloud, normal_pred, conf)
        norm_loss = normal_loss(normal_pred, normal, conf)
        consis_loss = torch.zeros_like(norm_loss)
        
        return mse_loss, mse_loss, mse_loss
        # return mse_loss, consis_loss, norm_loss
        # return norm_loss, norm_loss, norm_loss
    
    print('Start training')
    for epoch in range(start_epoch, epochs):
        # train
        model.train()
        for i, data in enumerate(dataloader):
            start = end
            i += 1
            data = [x.to(device) for x in data]
            image, depth, normal, conf, cloud = data
            depth_pred, normal_pred = model(image)
            mse_loss, consis_loss, norm_loss = criterion(depth_pred, normal_pred, depth, normal, cloud, conf)
            loss = mse_loss + consis_loss + norm_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # bookkeeping
            end = time.perf_counter()
            metric_logger.update(loss=loss.item())
            metric_logger.update(mse_loss=mse_loss.item())
            metric_logger.update(norm_loss=norm_loss.item())
            metric_logger.update(consis_loss=consis_loss.item())
            metric_logger.update(batch_time=end-start)

            
            if i % print_every == 0:
                # Compute eta. global step: starting from 1
                global_step = epoch * len(dataloader) + i
                seconds = (max_iters - global_step) * metric_logger['batch_time'].global_avg
                eta = datetime.timedelta(seconds=int(seconds))
                # to display: eta, epoch, iteration, loss, batch_time
                display_dict = {
                    'eta': eta,
                    'epoch': epoch,
                    'iter': i,
                    'loss': metric_logger['loss'].median,
                    'batch_time': metric_logger['batch_time'].median
                }
                display_str = [
                    'eta: {eta}s',
                    'epoch: {epoch}',
                    'iter: {iter}',
                    'loss: {loss:.4f}',
                    'batch_time: {batch_time:.4f}s',
                ]
                print(', '.join(display_str).format(**display_dict))
                
                # tensorboard
                min_depth = depth[0].min()
                max_depth = depth[0].max() * 1.25
                depth = (depth[0] - min_depth) / (max_depth - min_depth)
                depth_pred = (depth_pred[0] - min_depth) / (max_depth - min_depth)
                depth_pred = torch.clamp(depth_pred, min=0.0, max=1.0)
                normal = (normal[0] + 1) / 2
                normal_pred = (normal_pred[0] + 1) / 2
                conf = conf[0]
                
                tb.add_scalar('train/loss', metric_logger['loss'].median, global_step)
                tb.add_scalar('train/mse_loss', metric_logger['mse_loss'].median, global_step)
                tb.add_scalar('train/consis_loss', metric_logger['consis_loss'].median, global_step)
                tb.add_scalar('train/norm_loss', metric_logger['norm_loss'].median, global_step)
                
                tb.add_image('train/depth', depth, global_step)
                tb.add_image('train/normal', normal, global_step)
                tb.add_image('train/depth_pred', depth_pred, global_step)
                tb.add_image('train/normal_pred', normal_pred, global_step)
                tb.add_image('train/conf', conf, global_step)
                tb.add_image('train/image', image[0], global_step)
                
        if (epoch) % val_every == 0 and epoch != 0:
            # validate after each epoch
            validate(dataloader, model, device, tb, epoch, 'train')
            validate(dataloader_test, model, device, tb, epoch, 'test')
        if (epoch) % save_every == 0 and epoch != 0:
            to_save = {
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'epoch': epoch,
            }
            torch.save(to_save, os.path.join(model_dir, 'model.pth'))
Example #12
def train(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.cuda.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Some info
    print('Experiment name:', cfg.exp_name)
    print('Model name:', cfg.model)
    print('Dataset:', cfg.dataset)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:',
              cfg.resume_ckpt if cfg.resume_ckpt else 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('\nLoading data...')

    trainloader = get_dataloader(cfg, 'train')
    if cfg.val.ison or cfg.vis.ison:
        valset = get_dataset(cfg, 'val')
        valloader = get_dataloader(cfg, 'val')
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')
    model.train()

    optimizer = get_optimizer(cfg, model)

    # Checkpointer will print information.
    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name),
                                max_num=cfg.train.max_ckpt)

    start_epoch = 0
    start_iter = 0
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, optimizer)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name),
                           purge_step=global_step,
                           flush_secs=30)
    metric_logger = MetricLogger()
    vis_logger = get_vislogger(cfg)
    evaluator = get_evaluator(cfg)

    print('Start training')
    end_flag = False
    for epoch in range(start_epoch, cfg.train.max_epochs):
        if end_flag: break
        start = time.perf_counter()
        for i, data in enumerate(trainloader):
            end = time.perf_counter()
            data_time = end - start
            start = end

            imgs, *_ = [d.to(cfg.device) for d in data]
            model.train()
            loss, log = model(imgs, global_step)
            # If you are using DataParallel
            loss = loss.mean()
            optimizer.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)
            optimizer.step()

            end = time.perf_counter()
            batch_time = end - start

            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())

            if (global_step + 1) % cfg.train.print_every == 0:
                start = time.perf_counter()
                log.update(loss=metric_logger['loss'].median)
                vis_logger.model_log_vis(writer, log, global_step + 1)
                end = time.perf_counter()
                device_text = cfg.device_ids if cfg.parallel else cfg.device
                print(
                    'exp: {}, device: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'
                    .format(cfg.exp_name, device_text, epoch + 1, i + 1,
                            len(trainloader), global_step + 1,
                            metric_logger['loss'].median,
                            metric_logger['batch_time'].avg,
                            metric_logger['data_time'].avg, end - start))

            if (global_step + 1) % cfg.train.save_every == 0:
                start = time.perf_counter()
                checkpointer.save(model, optimizer, epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(
                    time.perf_counter() - start))

            if (global_step + 1) % cfg.vis.vis_every == 0 and cfg.vis.ison:
                print('Doing visualization...')
                start = time.perf_counter()
                vis_logger.train_vis(model,
                                     valset,
                                     writer,
                                     global_step,
                                     cfg.vis.indices,
                                     cfg.device,
                                     cond_steps=cfg.vis.cond_steps,
                                     fg_sample=cfg.vis.fg_sample,
                                     bg_sample=cfg.vis.bg_sample,
                                     num_gen=cfg.vis.num_gen)
                print(
                    'Visualization takes {:.4f}s.'.format(time.perf_counter() -
                                                          start))

            if (global_step + 1) % cfg.val.val_every == 0 and cfg.val.ison:
                print('Doing evaluation...')
                start = time.perf_counter()
                evaluator.train_eval(
                    evaluator, os.path.join(cfg.evaldir,
                                            cfg.exp_name), cfg.val.metrics,
                    cfg.val.eval_types, cfg.val.intervals, cfg.val.cond_steps,
                    model, valset, valloader, cfg.device, writer, global_step,
                    [model, optimizer, epoch, global_step], checkpointer)
                print('Evaluation takes {:.4f}s.'.format(time.perf_counter() -
                                                         start))

            start = time.perf_counter()
            global_step += 1
            if global_step >= cfg.train.max_steps:
                end_flag = True
                break
Example #13
def train(inputs, outputs, args, logger):
    """
     :param:
     - inputs: (list) 作为输入的tensor, 它是由preprocess.py处理得的
     - outputs: (tensor) 作为标注的tensor, 它是由preprocess.py处理得的
     - args: 一堆训练前规定好的参数
     - logger: 训练日志,可以把训练过程记录在./ckpt/log.txt
     :return: 训练结束
     """
    # 创建数据集
    # inputs[0] (50000,1024)即(data_num,max_input_len)
    # outputs (50000) 即(data_num)
    torch_dataset = Data.TensorDataset(inputs[0], inputs[1], inputs[2], outputs)
    loader = Data.DataLoader(dataset=torch_dataset, batch_size=args.batch_size, shuffle=True)
    logger.info('[1] Building model')
    # check which device is available when running the training script; use CUDA if present
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # build the model
    model = AlbertClassifierModel(num_topics=args.num_topics,
                                  out_channels=args.out_channels,
                                  max_input_len=args.max_input_len,
                                  kernel_size=args.kernel_size,
                                  dropout=args.dropout).to(device)
    model_kwargs = {k: getattr(args, k) for k in
                    {'num_topics', 'out_channels', 'max_input_len', 'kernel_size', 'dropout'}
                    }
    logger.info(model)
    # optimizer
    if args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optim == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    meters = MetricLogger(delimiter="  ")
    # BCEWithLogitsLoss would be the binary loss that needs no extra sigmoid;
    # here a multi-class CrossEntropyLoss is used instead
    criterion = nn.CrossEntropyLoss()
    # scheduler: multiply the learning rate by 0.1 at schedule_step; currently only this first decay is applied
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [args.schedule_step], gamma=0.1)
    logger.info('[2] Start training......')
    for epoch_num in range(args.max_epoch):
        # example_num: number of batches per epoch
        example_num = outputs.shape[0] // args.batch_size
        for batch_iter, (input_ids, segments_tensor, attention_mask, label) in enumerate(loader):
            progress = epoch_num + batch_iter / example_num
            optimizer.zero_grad()
            batch_size = args.batch_size
            # forward pass
            pred = model(input_ids.to(device).view(batch_size, -1),
                         attention_mask.to(device).view(batch_size, -1))
            # handle the label
            if label.shape[0] != args.batch_size:
                logger.info('last dummy batch')
                break
            label = label.view(args.batch_size)
            label = label.to(device)
            loss = criterion(pred, label)

            # backward pass
            loss.backward()
            optimizer.step()
            meters.update(loss=loss)
            # log the loss every 0.01 epoch
            if (batch_iter + 1) % (example_num // 100) == 0:
                logger.info(
                    meters.delimiter.join(
                        [
                            "progress: {prog:.2f}",
                            "{meters}",
                        ]
                    ).format(
                        prog=progress,
                        meters=str(meters),
                    )
                )
                # in debug mode, break early and go straight to validation
                if args.debug:
                    break
        # evaluate this epoch
        precision, score = validate(model, device, args)
        logger.info("val")
        logger.info("precision")
        logger.info(precision)
        logger.info("official score")
        logger.info(score)
        save = {
            'kwargs': model_kwargs,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

        scheduler.step()
        # keep one checkpoint per epoch
        torch.save(save,
                   os.path.join(args.save_dir, 'model_epoch%d_val%.3f.pt' % (epoch_num, score)))
Example #14
# train
for epoch in range(num_epochs):
    metric_logger = MetricLogger(delimiter=' ')
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    model.train()
    for images, targets in metric_logger.log_every(train_data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])



# inference
Example #15
def do_train(
    args,
    model,
    optimizer,
    scheduler,
    device
):
    # logging
    logger = logging.getLogger("efficientnet.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    # data
    traindir = os.path.join(args.data, 'train')
    normalize = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    train_dataset = torchvision.datasets.ImageFolder(
        traindir,
        torchvision.transforms.Compose([
            torchvision.transforms.RandomResizedCrop(224),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = make_data_sampler(train_dataset, True, args.distributed)
    data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )

    start_training_time = time.time()
    end = time.time()

    model.train()
    max_iter = len(data_loader)
    for epoch in range(args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        scheduler.step()
        for iteration, (images, targets) in enumerate(data_loader):
            data_time = time.time() - end

            images = images.to(device)
            targets = targets.to(device)

            loss = model(images, targets)
            # log the loss
            loss_reduced = reduce_loss(loss)
            meters.update(loss=loss_reduced)
            
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            # record timing statistics
            batch_time = time.time() - end
            end = time.time()
            meters.update(
                time=batch_time,
                data=data_time
            )
            eta_seconds = meters.time.global_avg * (max_iter * (args.epochs - epoch - 1) - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            # print training status
            if iteration % args.print_freq == 0:
                logger.info(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "epoch: {epoch}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "max mem (MB): {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        epoch=epoch,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    )
                )

        # save a checkpoint
        if is_main_process() and ((epoch + 1) % args.ckpt_freq == 0 or (epoch + 1) == args.epochs):
            ckpt = {}
            ckpt['model'] = model.state_dict()
            ckpt['optimizer'] = optimizer.state_dict()
            save_file = os.path.join(args.output_dir, "efficientnet-epoch-{}.pth".format(epoch))
            torch.save(ckpt, save_file)
        
        # validate
        do_eval(args, model, args.distributed)

    # total training time
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter * args.epochs)
        )
    )
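
Example #15 builds its sampler with a make_data_sampler helper that is not shown. A plausible minimal version, assuming it only has to switch between a DistributedSampler (which provides the set_epoch call used above) and plain random/sequential sampling:

import torch


def make_data_sampler(dataset, shuffle, distributed):
    """Return a sampler appropriate for single-process or distributed training."""
    if distributed:
        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return torch.utils.data.RandomSampler(dataset)
    return torch.utils.data.SequentialSampler(dataset)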
Example #16
def train(cfg):

    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:',
              cfg.resume_ckpt if cfg.resume_ckpt else 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('Loading data')

    if cfg.exp_name == 'table':
        data_set = np.load('{}/train/all_set_train.npy'.format(
            cfg.dataset_roots.TABLE))
        data_size = len(data_set)
    else:
        trainloader = get_dataloader(cfg, 'train')
        data_size = len(trainloader)
    if cfg.train.eval_on:
        valset = get_dataset(cfg, 'val')
        # valloader = get_dataloader(cfg, 'val')
        evaluator = get_evaluator(cfg)
    model = get_model(cfg)
    model = model.to(cfg.device)
    checkpointer = Checkpointer(osp.join(cfg.checkpointdir, cfg.exp_name),
                                max_num=cfg.train.max_ckpt)
    model.train()

    optimizer_fg, optimizer_bg = get_optimizers(cfg, model)

    start_epoch = 0
    start_iter = 0
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load_last(cfg.resume_ckpt, model,
                                            optimizer_fg, optimizer_bg)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name),
                           flush_secs=30,
                           purge_step=global_step)
    vis_logger = get_vislogger(cfg)
    metric_logger = MetricLogger()

    print('Start training')
    end_flag = False
    for epoch in range(start_epoch, cfg.train.max_epochs):
        if end_flag:
            break
        if cfg.exp_name == 'table':
            # create indices and shuffle them so the data can be accessed in random order
            idx_set = np.arange(data_size)
            np.random.shuffle(idx_set)
            idx_set = np.split(idx_set, len(idx_set) / cfg.train.batch_size)
            data_to_enumerate = idx_set
        else:
            trainloader = get_dataloader(cfg, 'train')
            data_to_enumerate = trainloader
            data_size = len(trainloader)

        start = time.perf_counter()
        for i, enumerated_data in enumerate(data_to_enumerate):

            end = time.perf_counter()
            data_time = end - start
            start = end

            model.train()
            if cfg.exp_name == 'table':
                data_i = data_set[enumerated_data]
                data_i = torch.from_numpy(data_i).float().to(cfg.device)
                data_i /= 255
                data_i = data_i.permute([0, 3, 1, 2])
                imgs = data_i
            else:

                imgs = enumerated_data
                imgs = imgs.to(cfg.device)

            loss, log = model(imgs, global_step)
            # In case of using DataParallel
            loss = loss.mean()
            optimizer_fg.zero_grad()
            optimizer_bg.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)

            optimizer_fg.step()

            # if cfg.train.stop_bg == -1 or global_step < cfg.train.stop_bg:
            optimizer_bg.step()

            end = time.perf_counter()
            batch_time = end - start

            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())

            if (global_step) % cfg.train.print_every == 0:
                start = time.perf_counter()
                log.update({
                    'loss': metric_logger['loss'].median,
                })
                vis_logger.train_vis(writer, log, global_step, 'train')
                end = time.perf_counter()

                print(
                    'exp: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'
                    .format(cfg.exp_name, epoch + 1, i + 1, data_size,
                            global_step, metric_logger['loss'].median,
                            metric_logger['batch_time'].avg,
                            metric_logger['data_time'].avg, end - start))
            if (global_step) % cfg.train.create_image_every == 0:
                vis_logger.test_create_image(
                    log,
                    '../output/{}_img_{}.png'.format(cfg.dataset, global_step))
            if (global_step) % cfg.train.save_every == 0:
                start = time.perf_counter()
                checkpointer.save_last(model, optimizer_fg, optimizer_bg,
                                       epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(
                    time.perf_counter() - start))

            if (global_step) % cfg.train.eval_every == 0 and cfg.train.eval_on:
                pass
                '''print('Validating...')
                start = time.perf_counter()
                checkpoint = [model, optimizer_fg, optimizer_bg, epoch, global_step]
                if cfg.exp_name == 'table':
                    evaluator.train_eval(model, None, None, writer, global_step, cfg.device, checkpoint, checkpointer)
                else:
                    evaluator.train_eval(model, valset, valset.bb_path, writer, global_step, cfg.device, checkpoint, checkpointer)

                print('Validation takes {:.4f}s.'.format(time.perf_counter() - start))'''

            start = time.perf_counter()
            global_step += 1
            if global_step > cfg.train.max_steps:
                end_flag = True
                break
Example #17
def training(model, data_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments):
    logger = logging.getLogger("RetinaNet.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        # backward on the (un-reduced) local losses
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
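
Example #17 (like the reduce_dict calls in Examples #7 and #14) averages the per-GPU loss dictionary across processes purely for logging. A sketch of that reduction, assuming torch.distributed is initialized and falling back to a no-op in single-process runs:

import torch
import torch.distributed as dist


def reduce_loss_dict(loss_dict, average=True):
    """All-reduce a dict of scalar loss tensors so every rank logs the same values."""
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() < 2:
        return loss_dict
    world_size = dist.get_world_size()
    with torch.no_grad():
        names = sorted(loss_dict.keys())
        values = torch.stack([loss_dict[k] for k in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        return {k: v for k, v in zip(names, values)}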
Example #18
def train(cfg):
    
    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:', cfg.resume_ckpt if cfg.resume_ckpt else 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)
    
    print('Loading data')

    trainloader = get_dataloader(cfg, 'train')
    if cfg.train.eval_on:
        valset = get_dataset(cfg, 'val')
        # valloader = get_dataloader(cfg, 'val')
        evaluator = get_evaluator(cfg)
    model = get_model(cfg)
    model = model.to(cfg.device)
    checkpointer = Checkpointer(osp.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    model.train()

    optimizer_fg, optimizer_bg = get_optimizers(cfg, model)

    start_epoch = 0
    start_iter = 0
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load_last(cfg.resume_ckpt, model, optimizer_fg, optimizer_bg)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)
    
    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name), flush_secs=30, purge_step=global_step)
    vis_logger = get_vislogger(cfg)
    metric_logger = MetricLogger()

    print('Start training')
    end_flag = False
    for epoch in range(start_epoch, cfg.train.max_epochs):
        if end_flag:
            break
    
        start = time.perf_counter()
        for i, data in enumerate(trainloader):
        
            end = time.perf_counter()
            data_time = end - start
            start = end
        
            model.train()
            imgs = data
            imgs = imgs.to(cfg.device)
            loss, log = model(imgs, global_step)
            # In case of using DataParallel
            loss = loss.mean()
            optimizer_fg.zero_grad()
            optimizer_bg.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)
        
            optimizer_fg.step()
        
            # if cfg.train.stop_bg == -1 or global_step < cfg.train.stop_bg:
            optimizer_bg.step()
        
            end = time.perf_counter()
            batch_time = end - start
        
            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())
        
            if (global_step) % cfg.train.print_every == 0:
                start = time.perf_counter()
                log.update({
                    'loss': metric_logger['loss'].median,
                })
                vis_logger.train_vis(writer, log, global_step, 'train')
                end = time.perf_counter()
            
                print(
                    'exp: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'.format(
                        cfg.exp_name, epoch + 1, i + 1, len(trainloader), global_step, metric_logger['loss'].median,
                        metric_logger['batch_time'].avg, metric_logger['data_time'].avg, end - start))
                
            if (global_step) % cfg.train.save_every == 0:
                start = time.perf_counter()
                checkpointer.save_last(model, optimizer_fg, optimizer_bg, epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(time.perf_counter() - start))
        
            if (global_step) % cfg.train.eval_every == 0 and cfg.train.eval_on:
                print('Validating...')
                start = time.perf_counter()
                checkpoint = [model, optimizer_fg, optimizer_bg, epoch, global_step]
                evaluator.train_eval(model, valset, valset.bb_path, writer, global_step, cfg.device, checkpoint, checkpointer)
                print('Validation takes {:.4f}s.'.format(time.perf_counter() - start))
        
            start = time.perf_counter()
            global_step += 1
            if global_step > cfg.train.max_steps:
                end_flag = True
                break