def main():
    """Run segmentation training or testing, selected by ``args.mode``.

    Relies on module-level state configured elsewhere in this file:
    ``args`` (parsed command-line options), ``best_result`` (best validation
    ``Result`` so far; read from the module when training starts without a
    checkpoint) and ``output_directory``.

    Modes:
        ``'test'``:  load the checkpoint at ``args.resume`` and run
            :func:`validate` once, printing mean IoU / mean accuracy.
        ``'train'``: create (or resume) model and optimizer, then train for
            iterations ``start_iter .. args.max_iter``. Gradients are
            accumulated over ``args.iter_size`` micro-batches per iteration.
            The model is validated and checkpointed every ``args.iter_save``
            iterations.

    Any other mode prints an error and exits with status -1.

    Raises:
        FileNotFoundError: if ``args.resume`` is set but is not a file.
    """
    global args, best_result, output_directory

    # Seed every RNG in use so runs are reproducible.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)

    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        print("Let's use", n_gpus, "GPUs!")
        # The model is wrapped in nn.DataParallel, which splits each batch
        # across GPUs; scale the batch so every GPU gets args.batch_size.
        args.batch_size = args.batch_size * n_gpus
    else:
        print("Let's use GPU ", torch.cuda.current_device())

    train_loader, val_loader = create_loader(args)

    if args.mode == 'test':
        if not args.resume:
            # Nothing to evaluate without a checkpoint; stop here instead of
            # failing later on an undefined model.
            print("no trained model to test.")
            return
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(args.resume))
        print("=> loading checkpoint '{}'".format(args.resume))
        # NOTE: torch.load unpickles arbitrary objects; only load trusted
        # checkpoints.
        checkpoint = torch.load(args.resume)

        epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']
        # Checkpoints store the whole model object (see save_checkpoint below).
        model = checkpoint['model']

        print("=> loaded checkpoint (epoch {})".format(epoch))

        # Drop the checkpoint dict and release now-unused cached GPU memory.
        del checkpoint
        torch.cuda.empty_cache()

        result, img_merge = validate(args,
                                     val_loader,
                                     model,
                                     epoch,
                                     logger=None)

        print(
            'Test Result: mean iou={result.mean_iou:.3f}, mean acc={result.mean_acc:.3f}.'
            .format(result=result))
    elif args.mode == 'train':
        if args.resume:
            if not os.path.isfile(args.resume):
                raise FileNotFoundError(
                    "=> no checkpoint found at '{}'".format(args.resume))
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE: torch.load unpickles arbitrary objects; only load trusted
            # checkpoints.
            checkpoint = torch.load(args.resume)

            start_iter = checkpoint['epoch'] + 1
            best_result = checkpoint['best_result']
            optimizer = checkpoint['optimizer']
            # Checkpoints store the whole (already DataParallel-wrapped) model.
            model = checkpoint['model']

            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

            # Drop the checkpoint dict and release now-unused cached GPU memory.
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> creating Model")
            model = get_models(args)
            print("=> model created.")
            start_iter = 1

            # Different modules get different learning rates: the "10x" group
            # is trained with ten times the base rate.
            train_params = [{
                'params': model.get_1x_lr_params(),
                'lr': args.lr
            }, {
                'params': model.get_10x_lr_params(),
                'lr': args.lr * 10
            }]

            print(train_params)

            optimizer = torch.optim.SGD(train_params,
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)

            # DataParallel works with a single GPU as well as with several.
            model = nn.DataParallel(model).cuda()

        scheduler = PolynomialLR(optimizer=optimizer,
                                 step_size=args.lr_decay,
                                 iter_max=args.max_iter,
                                 power=args.power)

        criterion = criteria._CrossEntropyLoss2d(size_average=True,
                                                 batch_average=True)

        # Output directory for config, logs, images and checkpoints.
        output_directory = utils.get_output_directory(args)
        os.makedirs(output_directory, exist_ok=True)
        best_txt = os.path.join(output_directory, 'best.txt')
        config_txt = os.path.join(output_directory, 'config.txt')

        # Record the training options once, one "key:value" per line.
        if not os.path.exists(config_txt):
            with open(config_txt, 'w') as txtfile:
                txtfile.write(''.join(
                    str(k) + ':' + str(v) + ',\t\n'
                    for k, v in vars(args).items()))

        # Fresh TensorBoard log directory named by timestamp and host.
        log_path = os.path.join(
            output_directory, 'logs',
            datetime.now().strftime('%b%d_%H-%M-%S') + '_' +
            socket.gethostname())
        if os.path.isdir(log_path):
            shutil.rmtree(log_path)
        os.makedirs(log_path)
        logger = SummaryWriter(log_path)

        model.train()
        if args.freeze:
            model.module.freeze_backbone_bn()
        output_directory = utils.get_output_directory(args, check=True)

        average_meter = AverageMeter()
        # Persistent iterator: restarted only when an epoch is exhausted.
        loader_iter = iter(train_loader)

        for it in tqdm(range(start_iter, args.max_iter + 1),
                       total=args.max_iter - start_iter + 1,
                       leave=False,
                       dynamic_ncols=True):
            # Clear gradients (ready to accumulate).
            optimizer.zero_grad()

            loss = 0
            data_time = 0
            gpu_time = 0

            for _ in range(args.iter_size):
                end = time.time()
                try:
                    samples = next(loader_iter)
                except StopIteration:
                    # Epoch finished: begin a new pass over the training set.
                    loader_iter = iter(train_loader)
                    samples = next(loader_iter)

                inputs = samples['image'].cuda()
                target = samples['label'].cuda()

                torch.cuda.synchronize()
                data_time_ = time.time()
                data_time += data_time_ - end

                # NOTE(review): anomaly detection is a debugging aid with
                # significant overhead; consider disabling it for real runs.
                with torch.autograd.detect_anomaly():
                    preds = model(inputs)  # @wx: pay attention to the output format

                    iter_loss = 0
                    if args.msc:
                        # Multi-scale outputs: resize the labels to each
                        # output's spatial size and sum the losses.
                        for pred in preds:
                            target_ = utils.resize_labels(
                                target,
                                shape=(pred.size()[-2], pred.size()[-1]))
                            iter_loss += criterion(pred, target_)
                    else:
                        pred = preds
                        target_ = utils.resize_labels(target,
                                                      shape=(pred.size()[-2],
                                                             pred.size()[-1]))
                        iter_loss += criterion(pred, target_)

                    # Scale so the accumulated gradient is a mean over
                    # micro-batches, then backpropagate.
                    iter_loss /= args.iter_size
                    iter_loss.backward()

                    loss += float(iter_loss)

                gpu_time += time.time() - data_time_

            torch.cuda.synchronize()

            # Update weights with the accumulated gradients.
            optimizer.step()

            # Update learning rate.
            scheduler.step(epoch=it)

            # Training metrics use the last micro-batch (for msc, the last
            # output in `preds`). The prediction is brought to the label's
            # spatial size, matching what validate() does.
            result = Result()
            with torch.no_grad():
                prob = F.softmax(pred, dim=1)
                if prob.shape[-2:] != target.shape[-2:]:
                    prob = F.interpolate(prob,
                                         size=tuple(target.shape[-2:]),
                                         mode='bilinear',
                                         align_corners=True)

            result.evaluate(prob.data.cpu().numpy(),
                            target.data.cpu().numpy(),
                            n_class=21)
            average_meter.update(result, gpu_time, data_time, inputs.size(0))

            if it % args.print_freq == 0:
                print('=> output: {}'.format(output_directory))
                print('Train Iter: [{0}/{1}]\t'
                      't_Data={data_time:.3f}({average.data_time:.3f}) '
                      't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                      'Loss={Loss:.5f} '
                      'MeanAcc={result.mean_acc:.3f}({average.mean_acc:.3f}) '
                      'MIOU={result.mean_iou:.3f}({average.mean_iou:.3f}) '.
                      format(it,
                             args.max_iter,
                             data_time=data_time,
                             gpu_time=gpu_time,
                             Loss=loss,
                             result=result,
                             average=average_meter.average()))
                logger.add_scalar('Train/Loss', loss, it)
                logger.add_scalar('Train/mean_acc', result.mean_acc, it)
                logger.add_scalar('Train/mean_iou', result.mean_iou, it)

                for i, param_group in enumerate(optimizer.param_groups):
                    old_lr = float(param_group['lr'])
                    logger.add_scalar('Lr/lr_' + str(i), old_lr, it)

            if it % args.iter_save == 0:
                val_result, img_merge = validate(args,
                                                 val_loader,
                                                 model,
                                                 epoch=it,
                                                 logger=logger)

                # Keep the best validation result; mean IoU is
                # higher-is-better.
                is_best = val_result.mean_iou > best_result.mean_iou
                if is_best:
                    best_result = val_result
                    with open(best_txt, 'w') as txtfile:
                        txtfile.write(
                            "Iter={}, mean_iou={:.3f}, mean_acc={:.3f}, "
                            "t_gpu={:.4f}".format(it, val_result.mean_iou,
                                                  val_result.mean_acc,
                                                  val_result.gpu_time))
                    if img_merge is not None:
                        img_filename = os.path.join(output_directory,
                                                    'comparison_best.png')
                        utils.save_image(img_merge, img_filename)

                # Save a checkpoint at every validation point.
                utils.save_checkpoint(
                    {
                        'args': args,
                        'epoch': it,
                        'model': model,
                        'best_result': best_result,
                        'optimizer': optimizer,
                    }, is_best, it, output_directory)

                # validate() switched the model to eval mode; switch back.
                model.train()
                if args.freeze:
                    model.module.freeze_backbone_bn()

        logger.close()
    else:
        print('no mode named as ', args.mode)
        raise SystemExit(-1)
class trainer(object):
    """Iteration-based trainer/evaluator that logs RMSE/REL/Log10/Delta metrics.

    The metric set (RMSE, absolute relative error, log10, delta thresholds)
    suggests a depth-regression task -- NOTE(review): confirm against Result.

    Args:
        opt: options object. Must provide ``max_iter``, ``print_freq`` and
            ``modality`` and a ``write_config(path)`` method, and must be
            accepted by the project helpers used below.
        model: network whose output is indexed with ``[0]`` for evaluation.
        optimizer: torch optimizer over ``model``'s parameters.
        start_iter: first iteration number (inclusive).
        best_result: best ``Result`` so far (e.g. from a checkpoint); when
            None a fresh worst-case ``Result`` is used.
    """

    def __init__(self, opt, model, optimizer, start_iter, best_result=None):
        self.opt = opt
        self.model = model.cuda()
        self.optimizer = optimizer
        self.scheduler = get_schedular(optimizer, self.opt)
        self.criterion = get_criteria(self.opt)

        self.output_directory = utils.get_save_path(self.opt)
        self.best_txt = os.path.join(self.output_directory, 'best.txt')
        self.logger = utils.get_logger(self.output_directory)
        opt.write_config(self.output_directory)

        self.st_iter, self.ed_iter = start_iter, self.opt.max_iter

        # data loaders
        from dataloaders import create_loader
        self.train_loader = create_loader(self.opt, mode='train')
        self.eval_loader = create_loader(self.opt, mode='val')
        # Persistent iterator over train_loader; (re)created by _next_batch.
        self._loader_iter = None

        if best_result is not None:
            self.best_result = best_result
        else:
            self.best_result = Result()
            self.best_result.set_to_worst()

        # Validate/checkpoint once per pass over the training set.
        self.iter_save = len(self.train_loader)
        self.train_meter = AverageMeter()
        self.eval_meter = AverageMeter()
        # Metric handed to the LR scheduler (absrel of the last validation).
        self.metric = self.best_result.absrel
        self.result = Result()

    def _next_batch(self):
        """Return the next training batch, starting a new epoch when exhausted.

        The iterator lives on the instance so consecutive calls walk through
        the loader instead of re-reading its first batch every time.
        """
        loader_iter = getattr(self, '_loader_iter', None)
        if loader_iter is not None:
            try:
                return next(loader_iter)
            except StopIteration:
                pass  # epoch finished; fall through and restart
        self._loader_iter = iter(self.train_loader)
        return next(self._loader_iter)

    def train_iter(self, it):
        """Run one optimization step and log training metrics every print_freq."""
        # Clear gradients.
        self.optimizer.zero_grad()

        end = time.time()
        inputs, target = self._next_batch()
        inputs, target = inputs.cuda(), target.cuda()
        data_time = time.time() - end

        # Forward, backward and update.
        end = time.time()
        pred = self.model(inputs)  # @wx: pay attention to the output format

        loss = self.criterion(pred, target)
        loss.backward()
        self.optimizer.step()

        gpu_time = time.time() - end

        # Measure accuracy on the first model output and record loss.
        self.result.set_to_worst()
        self.result.evaluate(pred[0], target, loss.item())
        self.train_meter.update(self.result, gpu_time, data_time,
                                inputs.size(0))

        avg = self.train_meter.average()
        if it % self.opt.print_freq == 0:
            print('=> output: {}'.format(self.output_directory))
            print('Train Iter: [{0}/{1}]\t'
                  't_Data={data_time:.3f}({average.data_time:.3f}) '
                  't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                  'Loss={Loss:.5f}({average.loss:.5f}) '
                  'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
                  'REL={result.absrel:.2f}({average.absrel:.2f}) '
                  'Log10={result.lg10:.3f}({average.lg10:.3f}) '
                  'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
                  'Delta2={result.delta2:.3f}({average.delta2:.3f}) '
                  'Delta3={result.delta3:.3f}({average.delta3:.3f})'.format(
                      it,
                      self.opt.max_iter,
                      data_time=data_time,
                      gpu_time=gpu_time,
                      Loss=loss.item(),
                      result=self.result,
                      average=avg))

            self.logger.add_scalar('Train/Loss', avg.loss, it)
            self.logger.add_scalar('Train/RMSE', avg.rmse, it)
            self.logger.add_scalar('Train/rel', avg.absrel, it)
            self.logger.add_scalar('Train/Log10', avg.lg10, it)
            self.logger.add_scalar('Train/Delta1', avg.delta1, it)
            self.logger.add_scalar('Train/Delta2', avg.delta2, it)
            self.logger.add_scalar('Train/Delta3', avg.delta3, it)

    def eval(self, it):
        """Evaluate on eval_loader, log averages and save a comparison image.

        One visualization row is collected every ``skip`` batches (8 rows);
        the grid is written to ``comparison_<it>.png`` when batch ``8*skip``
        is reached. No image is produced for modality ``'d'``.
        """
        # Clamp to 1 so loaders with fewer than 9 batches don't divide by zero.
        skip = max(1, len(self.eval_loader) // 9)
        self.eval_meter.reset()
        img_merge = None

        for i, (inputs, target) in enumerate(self.eval_loader):

            end = time.time()
            inputs, target = inputs.cuda(), target.cuda()
            data_time = time.time() - end

            # compute output
            end = time.time()
            with torch.no_grad():
                pred = self.model(inputs)
            gpu_time = time.time() - end

            # measure accuracy
            self.result.set_to_worst()
            self.result.evaluate(pred[0], target)
            self.eval_meter.update(self.result, gpu_time, data_time,
                                   inputs.size(0))

            if i % skip == 0 and self.opt.modality != 'd':
                if i == 8 * skip:
                    filename = self.output_directory + '/comparison_' + str(
                        it) + '.png'
                    utils.save_image(img_merge, filename)
                elif i < 8 * skip:
                    # Upsample the prediction to the target size if needed,
                    # then visualize the first sample of the batch.
                    vis_pred = pred[0]
                    h, w = target.size(2), target.size(3)
                    if h != vis_pred.size(2) or w != vis_pred.size(3):
                        vis_pred = F.interpolate(input=vis_pred,
                                                 size=(h, w),
                                                 mode='bilinear',
                                                 align_corners=True)

                    data = inputs[0]
                    vis_target = target[0]
                    vis_pred = vis_pred[0]

                    if self.opt.modality == 'rgb':
                        row = utils.merge_into_row(data, vis_target, vis_pred)
                    elif self.opt.modality == 'rgbd':
                        # First 3 channels are RGB, the rest is depth input.
                        row = utils.merge_into_row_with_gt(
                            data[:3, :, :], data[3:, :, :], vis_target,
                            vis_pred)
                    else:
                        raise ValueError('unsupported modality: {}'.format(
                            self.opt.modality))

                    img_merge = row if i == 0 else utils.add_row(
                        img_merge, row)

            if (i + 1) % self.opt.print_freq == 0:
                print(
                    'Test: [{0}/{1}]\t'
                    't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                    'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
                    'REL={result.absrel:.2f}({average.absrel:.2f}) '
                    'Log10={result.lg10:.3f}({average.lg10:.3f}) '
                    'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
                    'Delta2={result.delta2:.3f}({average.delta2:.3f}) '
                    'Delta3={result.delta3:.3f}({average.delta3:.3f}) '.format(
                        i + 1,
                        len(self.eval_loader),
                        gpu_time=gpu_time,
                        result=self.result,
                        average=self.eval_meter.average()))

        avg = self.eval_meter.average()

        self.logger.add_scalar('Test/RMSE', avg.rmse, it)
        self.logger.add_scalar('Test/rel', avg.absrel, it)
        self.logger.add_scalar('Test/Log10', avg.lg10, it)
        self.logger.add_scalar('Test/Delta1', avg.delta1, it)
        self.logger.add_scalar('Test/Delta2', avg.delta2, it)
        self.logger.add_scalar('Test/Delta3', avg.delta3, it)

        print('\n*\n'
              'RMSE={average.rmse:.3f}\n'
              'Rel={average.absrel:.3f}\n'
              'Log10={average.lg10:.3f}\n'
              'Delta1={average.delta1:.3f}\n'
              'Delta2={average.delta2:.3f}\n'
              'Delta3={average.delta3:.3f}\n'
              't_GPU={time:.3f}\n'.format(average=avg, time=avg.gpu_time))

    def train_eval(self):
        """Main loop: train each iteration, validate/checkpoint every iter_save."""
        for it in tqdm(range(self.st_iter, self.ed_iter + 1),
                       total=self.ed_iter - self.st_iter + 1,
                       leave=False,
                       dynamic_ncols=True):
            self.model.train()
            self.train_iter(it)

            # Log the current learning rate of every parameter group.
            for i, param_group in enumerate(self.optimizer.param_groups):
                old_lr = float(param_group['lr'])
                self.logger.add_scalar('Lr/lr_' + str(i), old_lr, it)

            if it % self.iter_save == 0:
                self.model.eval()
                self.eval(it)

                train_avg = self.train_meter.average()
                eval_avg = self.eval_meter.average()
                self.metric = eval_avg.absrel

                self.logger.add_scalars('TrainVal/rmse', {
                    'train_rmse': train_avg.rmse,
                    'test_rmse': eval_avg.rmse
                }, it)
                self.logger.add_scalars('TrainVal/rel', {
                    'train_rel': train_avg.absrel,
                    'test_rel': eval_avg.absrel
                }, it)
                self.logger.add_scalars('TrainVal/lg10', {
                    'train_lg10': train_avg.lg10,
                    'test_lg10': eval_avg.lg10
                }, it)
                self.logger.add_scalars('TrainVal/Delta1', {
                    'train_d1': train_avg.delta1,
                    'test_d1': eval_avg.delta1
                }, it)
                self.logger.add_scalars('TrainVal/Delta2', {
                    'train_d2': train_avg.delta2,
                    'test_d2': eval_avg.delta2
                }, it)
                self.logger.add_scalars('TrainVal/Delta3', {
                    'train_d3': train_avg.delta3,
                    'test_d3': eval_avg.delta3
                }, it)

                self.train_meter.reset()

                # Lower absrel is better: remember best and save checkpoint.
                is_best = eval_avg.absrel < self.best_result.absrel
                if is_best:
                    self.best_result = eval_avg
                    with open(self.best_txt, 'w') as txtfile:
                        txtfile.write(
                            "Iter={}, rmse={:.3f}, rel={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, d3={:.3f}, "
                            "t_gpu={:.4f}".format(
                                it, eval_avg.rmse, eval_avg.absrel,
                                eval_avg.lg10, eval_avg.delta1,
                                eval_avg.delta2, eval_avg.delta3,
                                eval_avg.gpu_time))

                utils.save_checkpoint(
                    {
                        'args': self.opt,
                        'epoch': it,
                        'state_dict': self.model.state_dict(),
                        'best_result': self.best_result,
                        'optimizer': self.optimizer,
                    }, is_best, it, self.output_directory)

            # Update learning rate.
            do_schedule(self.opt,
                        self.scheduler,
                        it=it,
                        len=self.iter_save,
                        metrics=self.metric)

        self.logger.close()
def validate(args, val_loader, model, epoch, logger):
    """Evaluate ``model`` on ``val_loader`` and return the averaged result.

    The model is switched to eval mode. Predictions are soft-maxed, resized
    to the label's spatial size if needed, and optionally refined with a
    DenseCRF (``args.crf``). They are then scored with ``Result.evaluate``
    over 21 classes.

    Args:
        args: options; reads ``crf`` and ``print_freq`` and is passed to
            ``utils.get_output_directory``.
        val_loader: iterable of dicts with ``'image'`` and ``'label'`` tensors.
        model: segmentation network.
        epoch: iteration/epoch number, used for logging and image file names.
        logger: TensorBoard-style writer, or None to skip scalar logging.

    Returns:
        ``(avg, img_merge)``: the averaged ``Result``, and the comparison-image
        grid built from up to 8 batches (None if no batch was visualized).
    """
    average_meter = AverageMeter()
    model.eval()  # switch to evaluation mode

    output_directory = utils.get_output_directory(args, check=True)
    # Save one comparison row every `skip` batches; clamp to 1 so loaders
    # with fewer than 9 batches don't raise ZeroDivisionError.
    skip = max(1, len(val_loader) // 9)
    img_merge = None

    if args.crf:
        ITER_MAX = 10
        POS_W = 3
        POS_XY_STD = 1
        BI_W = 4
        BI_XY_STD = 67
        BI_RGB_STD = 3

        postprocessor = DenseCRF(
            iter_max=ITER_MAX,
            pos_xy_std=POS_XY_STD,
            pos_w=POS_W,
            bi_xy_std=BI_XY_STD,
            bi_rgb_std=BI_RGB_STD,
            bi_w=BI_W,
        )

    end = time.time()

    for i, samples in enumerate(val_loader):
        inputs, target = samples['image'].cuda(), samples['label'].cuda()
        torch.cuda.synchronize()
        data_time = time.time() - end

        # compute pred
        end = time.time()
        with torch.no_grad():
            pred = model(inputs)  # @wx: pay attention to the output format
        torch.cuda.synchronize()
        gpu_time = time.time() - end

        result = Result()
        pred = F.softmax(pred, 1)

        # Compare spatial dims only: pred carries a class axis the label
        # presumably lacks.
        if pred.shape[-2:] != target.shape[-2:]:
            pred = F.interpolate(pred,
                                 size=(target.size()[-2], target.size()[-1]),
                                 mode='bilinear',
                                 align_corners=True)

        pred = pred.data.cpu().numpy()
        target = target.data.cpu().numpy()

        # Post processing
        if args.crf:
            # NOTE(review): casting network inputs to uint8 assumes they are
            # un-normalized 0..255 images -- confirm against the dataloader.
            images = inputs.data.cpu().numpy().astype(np.uint8).transpose(
                0, 2, 3, 1)
            pred = np.asarray(joblib.Parallel(n_jobs=-1)([
                joblib.delayed(postprocessor)(*pair)
                for pair in zip(images, pred)
            ]))

        result.evaluate(pred, target, n_class=21)
        average_meter.update(result, gpu_time, data_time, inputs.size(0))

        # Visualize the first sample of every `skip`-th batch (8 rows),
        # saving the grid once batch 8*skip is reached.
        if i % skip == 0:
            if i == 8 * skip:
                filename = output_directory + '/comparison_' + str(
                    epoch) + '.png'
                utils.save_image(img_merge, filename)
            elif i < 8 * skip:
                rgb = inputs[0].data.cpu().numpy()
                label_pred = np.argmax(pred[0], axis=0)
                row = utils.merge_into_row(rgb, target[0], label_pred)
                img_merge = row if i == 0 else utils.add_row(img_merge, row)

        if (i + 1) % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                  'mean_acc={result.mean_acc:.3f}({average.mean_acc:.3f})  '
                  'mean_iou={result.mean_iou:.3f}({average.mean_iou:.3f})'.
                  format(i + 1,
                         len(val_loader),
                         gpu_time=gpu_time,
                         result=result,
                         average=average_meter.average()))

        # Restart the data timer only after visualization/printing so the
        # next data_time measures loading alone.
        end = time.time()

    avg = average_meter.average()
    if logger is not None:
        logger.add_scalar('Test/mean_acc', avg.mean_acc, epoch)
        logger.add_scalar('Test/mean_iou', avg.mean_iou, epoch)

    print('\n*\n'
          'mean_acc={average.mean_acc:.3f}\n'
          'mean_iou={average.mean_iou:.3f}\n'
          't_GPU={time:.3f}\n'.format(average=avg, time=avg.gpu_time))

    return avg, img_merge