def train(train_dataloader, model, epoch, loss_func,
          optimizer, scheduler, training_stats, val_dataloader=None, val_err=None, ignore_step=-1):
    """
    Train the model for one epoch, one optimisation step per batch.

    Each iteration steps the LR scheduler, runs the forward pass, computes the
    loss from ``out['b_fake_softmax']`` and ``out['b_fake_logit']``, updates the
    parameters through ``optimizer.optim`` and logs iteration statistics.
    Every ``cfg.TRAIN.VAL_STEP`` global steps the model is validated (if
    ``val_dataloader`` is given) and every ``cfg.TRAIN.SNAPSHOT_ITERS`` global
    steps a checkpoint is saved.

    Args:
        train_dataloader: iterable of training batches. ``len(train_dataloader)``
            is divided by ``cfg.TRAIN.BATCHSIZE`` to get the steps per epoch
            (assumes ``len`` counts samples rather than batches -- TODO confirm).
        model: network; ``model(data)`` must return a dict holding
            ``'b_fake_softmax'`` and ``'b_fake_logit'``.
        epoch: current epoch index, used to derive the global step number.
        loss_func: object whose ``criterion(softmax, logit, data, epoch)``
            returns the losses.
        optimizer: wrapper exposing ``optim(losses)`` and the wrapped
            ``optimizer`` (passed on to logging and checkpointing).
        scheduler: LR scheduler, stepped at the start of every iteration.
        training_stats: tracker used for iteration timing and loss logging.
        val_dataloader: validation data, or None to skip validation.
        val_err: one-element list whose item holds the latest validation
            result; it is updated in place so the caller sees the new value.
            Defaults to a fresh ``[None]`` per call.
        ignore_step: -1 to run the whole epoch; otherwise an offset added to
            the global step count, and the loop returns once the batch index
            exceeds ``epoch_steps - ignore_step`` (presumably used when
            resuming part-way through an epoch).
    """
    # A fresh list per call: the former ``val_err=[]`` default was shared
    # between calls and made ``val_err[0]`` below raise IndexError.
    # NOTE(review): LogIterStats/save_ckpt must accept None until the first
    # validation has run -- confirm, or pass an initialised list explicitly.
    if val_err is None:
        val_err = [None]

    model.train()
    epoch_steps = math.ceil(len(train_dataloader) / cfg.TRAIN.BATCHSIZE)
    base_steps = epoch_steps * epoch
    if ignore_step != -1:
        base_steps += ignore_step

    for i, data in enumerate(train_dataloader):
        if ignore_step != -1 and i > epoch_steps - ignore_step:
            return
        # NOTE(review): since PyTorch 1.1.0 scheduler.step() is expected after
        # optimizer.step(); the original ordering is kept here.
        scheduler.step()  # decay lr every iteration
        training_stats.IterTic()
        out = model(data)
        losses = loss_func.criterion(out['b_fake_softmax'], out['b_fake_logit'], data, epoch)
        optimizer.optim(losses)
        step = base_steps + i + 1  # global step, counted from 1 within epoch 0
        training_stats.UpdateIterStats(losses)
        training_stats.IterToc()
        training_stats.LogIterStats(step, epoch, optimizer.optimizer, val_err[0])

        # validate the model
        if step % cfg.TRAIN.VAL_STEP == 0 and step != 0 and val_dataloader is not None:
            model.eval()
            val_err[0] = val(val_dataloader, model)
            model.train()  # back to training mode

        # save checkpoint
        if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0:
            save_ckpt(train_args, step, epoch, model, optimizer.optimizer, scheduler, val_err[0])
示例#2
0
def train(train_dataloader, model, epoch, loss_func,
          optimizer, scheduler, training_stats, val_dataloader=None, val_err=None, ignore_step=-1):
    """
    Train the model for one epoch, one optimisation step per batch.

    Each iteration steps the LR scheduler, runs the forward pass, computes the
    loss, updates the parameters through ``optimizer.optim`` and logs
    iteration statistics. When ``train_args.refine`` is set the loss is
    computed from ``out['refined_depth']``; otherwise from
    ``out['b_fake_softmax']`` and ``out['b_fake_logit']``. Every
    ``cfg.TRAIN.VAL_STEP`` global steps the model is validated (if
    ``val_dataloader`` is given) and every ``cfg.TRAIN.SNAPSHOT_ITERS`` global
    steps a checkpoint is saved.

    Args:
        train_dataloader: iterable of training batches. ``len(train_dataloader)``
            is divided by ``cfg.TRAIN.BATCHSIZE`` to get the steps per epoch
            (assumes ``len`` counts samples rather than batches -- TODO confirm).
        model: network; ``model(data)`` must return a dict holding the outputs
            named above.
        epoch: current epoch index, used to derive the global step number.
        loss_func: object whose ``criterion`` is called as
            ``criterion(refined_depth, data)`` in refine mode, else as
            ``criterion(softmax, logit, data, epoch)``.
        optimizer: wrapper exposing ``optim(losses)`` and the wrapped
            ``optimizer`` (passed on to logging and checkpointing).
        scheduler: LR scheduler, stepped at the start of every iteration.
        training_stats: tracker used for iteration timing and loss logging.
        val_dataloader: validation data, or None to skip validation.
        val_err: one-element list whose item holds the latest validation
            result; it is updated in place so the caller sees the new value.
            Defaults to a fresh ``[None]`` per call.
        ignore_step: -1 to run the whole epoch; otherwise an offset added to
            the global step count, and the loop returns once the batch index
            exceeds ``epoch_steps - ignore_step`` (presumably used when
            resuming part-way through an epoch).
    """
    # A fresh list per call: the former ``val_err=[]`` default was shared
    # between calls and made ``val_err[0]`` below raise IndexError.
    # NOTE(review): LogIterStats/save_ckpt must accept None until the first
    # validation has run -- confirm, or pass an initialised list explicitly.
    if val_err is None:
        val_err = [None]

    model.train()
    epoch_steps = math.ceil(len(train_dataloader) / cfg.TRAIN.BATCHSIZE)
    base_steps = epoch_steps * epoch
    if ignore_step != -1:
        base_steps += ignore_step

    # The debug `print("step:", ...)` and the trailing `break` (which stopped
    # the epoch after a single batch) have been removed.
    for i, data in enumerate(train_dataloader):
        if ignore_step != -1 and i > epoch_steps - ignore_step:
            return
        # NOTE(review): since PyTorch 1.1.0 scheduler.step() is expected after
        # optimizer.step(); the original ordering is kept here.
        scheduler.step()  # decay lr every iteration
        training_stats.IterTic()
        out = model(data)

        if train_args.refine:
            # Refine mode: the loss is computed from the refined depth output.
            losses = loss_func.criterion(out['refined_depth'], data)
        else:
            losses = loss_func.criterion(out['b_fake_softmax'], out['b_fake_logit'], data, epoch)
        optimizer.optim(losses)
        step = base_steps + i + 1  # global step, counted from 1 within epoch 0
        training_stats.UpdateIterStats(losses)
        training_stats.IterToc()
        training_stats.LogIterStats(step, epoch, optimizer.optimizer, val_err[0])

        # validate the model
        if step % cfg.TRAIN.VAL_STEP == 0 and step != 0 and val_dataloader is not None:
            model.eval()
            val_err[0] = val(val_dataloader, model)
            model.train()  # back to training mode

        # save checkpoint
        if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0:
            save_ckpt(train_args, step, epoch, model, optimizer.optimizer, scheduler, val_err[0])
示例#3
0
                # Fetch the next batch. If next() raises (e.g. the iterator is
                # exhausted) a fresh iterator is created and the batch fetched again.
                data = next(dataloader_iterator)
            # NOTE(review): bare `except:` also catches KeyboardInterrupt and data
            # loader errors, not only StopIteration -- consider `except StopIteration:`.
            except:
                dataloader_iterator = iter(train_dataloader)
                data = next(dataloader_iterator)

            # Forward pass, loss computation and parameter update for this step.
            training_stats.IterTic()
            out = model(data)
            losses = loss_func.criterion(out['b_fake_softmax'], out['b_fake_logit'], data)
            optimizer.optim(losses)
            training_stats.UpdateIterStats(losses)
            training_stats.IterToc()
            # Statistics are logged with the epoch argument fixed to 0.
            training_stats.LogIterStats(step, 0, optimizer.optimizer, val_err[0])
            # validate the model
            # NOTE(review): validation fires on (step+1) % VAL_STEP while the
            # checkpoint below uses step % SNAPSHOT_ITERS -- confirm the offset is intended.
            if (step+1) % cfg.TRAIN.VAL_STEP == 0  and val_dataloader is not None and step != 0:
                model.eval()
                val_err[0] = val(val_dataloader, model)
                # training mode
                model.train()
            # save checkpoint
            # `epoch` and `scheduler` are defined outside this visible fragment.
            if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0:
                save_ckpt(train_args, step, epoch, model, optimizer.optimizer, scheduler, val_err[0])


    except (RuntimeError, KeyboardInterrupt):
        # NOTE(review): the message says a checkpoint is saved, but this handler
        # only prints the traceback -- no save_ckpt call is made here.
        logger.info('Save ckpt on exception ...')
        stack_trace = traceback.format_exc()
        print(stack_trace)
    finally:
        # Close the TensorBoard logger whenever it was enabled.
        if train_args.use_tfboard:
            tblogger.close()
示例#4
0
def do_train(train_dataloader,
             val_dataloader,
             train_args,
             model,
             save_to_disk,
             scheduler,
             optimizer,
             val_err,
             logger,
             tblogger=None):
    """
    Run iteration-based training from ``train_args.start_step`` up to
    ``cfg.TRAIN.MAX_ITER``.

    Every ``train_args.sample_ratio_steps`` steps the training loader is
    rebuilt with the sample ratio returned by ``increase_sample_ratio_steps``.
    Each step fetches a batch (restarting the loader when it is exhausted),
    runs the model, updates the parameters, steps the LR scheduler and, when
    ``save_to_disk`` is set, logs statistics. Validation runs every
    ``cfg.TRAIN.VAL_STEP`` steps and checkpoints are written every
    ``cfg.TRAIN.SNAPSHOT_ITERS`` steps (only when ``save_to_disk``).

    Args:
        train_dataloader: training loader; its
            ``batch_sampler.sampler.total_sampled_size`` is used as the data
            size when computing the epoch number.
        val_dataloader: validation loader, or None to skip validation.
        train_args: parsed training options (start_step, batchsize,
            world_size, sample-ratio settings, use_tfboard, ...).
        model: network; ``model(data)`` must return a dict with a
            ``'losses'`` entry.
        save_to_disk: whether this process tracks statistics and writes
            checkpoints.
        scheduler: LR scheduler, stepped once per iteration after the update.
        optimizer: wrapper exposing ``optim(losses)`` and the wrapped
            ``optimizer``.
        val_err: one-element list holding the latest validation result;
            updated in place.
        logger: logger for informational messages.
        tblogger: TensorBoard logger, closed at the end on the main process
            when ``train_args.use_tfboard`` is set.

    RuntimeError and KeyboardInterrupt end training early: the traceback is
    printed and the function returns normally.
    """
    logger.info('Base learning rate: %s' % (cfg.TRAIN.BASE_LR,))

    # Training statistics are only tracked by the process that writes to disk.
    training_stats = None
    if save_to_disk:
        training_stats = TrainingStats(
            train_args, cfg.TRAIN.LOG_INTERVAL,
            tblogger if train_args.use_tfboard else None)

    dataloader_iterator = iter(train_dataloader)
    start_step = train_args.start_step
    total_iters = cfg.TRAIN.MAX_ITER
    # NOTE(review): not refreshed when the loader is rebuilt with a new sample
    # ratio below, so the epoch number stays relative to the initial size.
    train_datasize = train_dataloader.batch_sampler.sampler.total_sampled_size

    try:
        for step in range(start_step, total_iters):

            if step % train_args.sample_ratio_steps == 0 and step != 0:
                sample_ratio = increase_sample_ratio_steps(
                    step,
                    base_ratio=train_args.sample_start_ratio,
                    step_size=train_args.sample_ratio_steps)
                train_dataloader, curr_sample_size = MultipleDataLoaderDistributed(
                    train_args, sample_ratio=sample_ratio)
                dataloader_iterator = iter(train_dataloader)
                logger.info(
                    'Sample ratio: %02f, current sampled datasize: %d' %
                    (sample_ratio, np.sum(curr_sample_size)))

            epoch = int(step * train_args.batchsize * train_args.world_size /
                        train_datasize)
            if save_to_disk:
                training_stats.IterTic()

            # Get the next batch, restarting the loader once it is exhausted.
            # Only StopIteration is caught so KeyboardInterrupt and loader
            # errors reach the handler below instead of silently restarting.
            try:
                data = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(train_dataloader)
                data = next(dataloader_iterator)

            out = model(data)
            losses_dict = out['losses']
            optimizer.optim(losses_dict)

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(losses_dict)

            scheduler.step()
            if save_to_disk:
                training_stats.UpdateIterStats(loss_dict_reduced)
                training_stats.IterToc()
                training_stats.LogIterStats(step, epoch, optimizer.optimizer,
                                            val_err[0])

            # validate the model
            if step % cfg.TRAIN.VAL_STEP == 0 and val_dataloader is not None and step != 0:
                model.eval()
                val_err[0] = val(val_dataloader, model)
                model.train()  # back to training mode
            # save checkpoint
            if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0 and save_to_disk:
                save_ckpt(train_args, step, epoch, model, optimizer.optimizer,
                          scheduler, val_err[0])

    except (RuntimeError, KeyboardInterrupt):
        stack_trace = traceback.format_exc()
        print(stack_trace)
    finally:
        # tblogger defaults to None, so guard before closing it.
        if (train_args.use_tfboard and main_process(train_args)
                and tblogger is not None):
            tblogger.close()
def do_train(train_dataloader,
             val_dataloader,
             train_args,
             model,
             save_to_disk,
             scheduler,
             optimizer,
             val_err,
             logger,
             tblogger=None):
    """
    Run iteration-based training from ``train_args.start_step`` up to
    ``cfg.TRAIN.MAX_ITER``.

    Every ``train_args.sample_ratio_steps`` steps the training loader is
    rebuilt with the sample ratio returned by ``increase_sample_ratio_steps``.
    Each step fetches a batch (restarting the loader when it is exhausted),
    runs the model, updates the parameters and steps the LR scheduler (before
    the update on PyTorch < 1.1.0, after it otherwise). When ``save_to_disk``
    is set, statistics are logged. Validation runs every
    ``cfg.TRAIN.VAL_STEP`` steps and checkpoints are written every
    ``cfg.TRAIN.SNAPSHOT_ITERS`` steps (only when ``save_to_disk``).

    Args:
        train_dataloader: training loader; ``len(train_dataloader)`` is used
            as the data size when computing the epoch number.
        val_dataloader: validation loader, or None to skip validation.
        train_args: parsed training options (start_step, batchsize,
            sample-ratio settings, use_tfboard, ...).
        model: network; ``model(data)`` must return a dict with a
            ``'losses'`` entry.
        save_to_disk: whether this process tracks statistics and writes
            checkpoints.
        scheduler: LR scheduler, stepped once per iteration.
        optimizer: wrapper exposing ``optim(losses)`` and the wrapped
            ``optimizer``.
        val_err: one-element list holding the latest validation result;
            updated in place.
        logger: logger for informational messages.
        tblogger: TensorBoard logger, closed at the end on rank 0 when
            ``train_args.use_tfboard`` is set.

    RuntimeError and KeyboardInterrupt end training early: the traceback is
    printed and the function returns normally.
    """
    # Training statistics are only tracked by the process that writes to disk.
    training_stats = None
    if save_to_disk:
        training_stats = TrainingStats(
            train_args, cfg.TRAIN.LOG_INTERVAL,
            tblogger if train_args.use_tfboard else None)

    dataloader_iterator = iter(train_dataloader)
    start_step = train_args.start_step
    total_iters = cfg.TRAIN.MAX_ITER
    # NOTE(review): the epoch formula multiplies steps by batchsize, which
    # assumes len(train_dataloader) counts samples rather than batches --
    # confirm. It is also not refreshed when the loader is rebuilt below.
    train_datasize = len(train_dataloader)

    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()

    try:
        for step in range(start_step, total_iters):

            if step % train_args.sample_ratio_steps == 0 and step != 0:
                sample_ratio = increase_sample_ratio_steps(
                    step,
                    base_ratio=train_args.sample_start_ratio,
                    step_size=train_args.sample_ratio_steps)
                train_dataloader = MultipleDataLoaderDistributed(
                    train_args, sample_ratio=sample_ratio)
                dataloader_iterator = iter(train_dataloader)
                logger.info(
                    'Sample ratio: %02f, current sampled datasize: %d' %
                    (sample_ratio, np.sum(train_dataloader.curr_sample_size)))

            # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
            if not pytorch_1_1_0_or_later:
                scheduler.step()

            epoch = int(step * train_args.batchsize / train_datasize)
            if save_to_disk:
                training_stats.IterTic()

            # Get the next batch, restarting the loader once it is exhausted.
            # Only StopIteration is caught so KeyboardInterrupt and loader
            # errors reach the handler below instead of silently restarting.
            try:
                data = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(train_dataloader)
                data = next(dataloader_iterator)

            out = model(data)
            losses_dict = out['losses']
            optimizer.optim(losses_dict)

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(losses_dict)

            if pytorch_1_1_0_or_later:
                scheduler.step()
            if save_to_disk:
                training_stats.UpdateIterStats(loss_dict_reduced)
                training_stats.IterToc()
                training_stats.LogIterStats(step, epoch, optimizer.optimizer,
                                            val_err[0])

            # validate the model
            if step % cfg.TRAIN.VAL_STEP == 0 and val_dataloader is not None and step != 0:
                model.eval()
                val_err[0] = val(val_dataloader, model)
                model.train()  # back to training mode
            # save checkpoint
            if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0 and save_to_disk:
                save_ckpt(train_args, step, epoch, model, optimizer.optimizer,
                          scheduler, val_err[0])

    except (RuntimeError, KeyboardInterrupt):
        stack_trace = traceback.format_exc()
        print(stack_trace)
    finally:
        # tblogger defaults to None, so guard before closing it.
        if (train_args.use_tfboard and get_rank() == 0
                and tblogger is not None):
            tblogger.close()