def validate(loader, model, criterion_rgb, criterion_local, epoch=0):
    losses = AverageMeter()
    metric = Metrics(max_depth=args.max_depth, disp=args.use_disp, normal=args.normal)
    score = AverageMeter()
    score_1 = AverageMeter()
    loss_rgb = torch.zeros(1)
    # Evaluate model
    model.eval()
    # Only forward pass, hence no grads needed
    with torch.no_grad():
        for i, (input, gt) in enumerate(loader):
            if not args.no_cuda:
                input, gt = input.cuda(non_blocking=True), gt.cuda(non_blocking=True)
            prediction = model(input)
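            # Dual-branch models ('mod'/'stacked') return a tuple:
            # (local depth prediction, RGB-guided prediction). The total loss
            # weights the RGB branch by args.wrgb; metrics are computed on
            # the local branch only.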

            if 'mod' in args.mod or 'stacked' in args.mod:
                loss = criterion_local(prediction[0], gt)
                loss_rgb = criterion_rgb(prediction[1], gt)
                loss += args.wrgb*loss_rgb
                prediction = prediction[0]
            else:
                loss = criterion_local(prediction, gt)

            losses.update(loss.item(), input.size(0))

            metric.calculate(prediction[:, 0:1], gt)
            score.update(metric.get_metric(args.metric), metric.num)
            score_1.update(metric.get_metric(args.metric_1), metric.num)

            if (i + 1) % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Metric {score.val:.4f} ({score.avg:.4f})'.format(
                       i+1, len(loader), loss=losses,
                       score=score))

        # Aggregate the meters across all DDP processes before reporting
        if args.world_size > 1:
            score.synchronize_between_processes()
            score_1.synchronize_between_processes()

        if args.evaluate:
            print("===> Average RMSE score on validation set is {:.4f}".format(score.avg))
            print("===> Average MAE score on validation set is {:.4f}".format(score_1.avg))
    return score.avg, score_1.avg, losses.avg


def main():
    global args
    args = parser.parse_args()
    if args.num_samples == 0:
        # Use all LiDAR points
        args.num_samples = None
    else:
        # Only the precomputed 500-sample subset is supported
        args.data_path = ""  # path to precomputed 500 samples
        assert args.num_samples == 500
        print("Changed data path to the samples500 dataset")
    if args.val_batch_size is None:
        args.val_batch_size = args.batch_size
    if args.seed:
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    init_distributed_mode(args)
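    # init_distributed_mode is assumed to follow the torchvision reference
    # pattern: it reads the launch environment, sets args.gpu and
    # args.world_size, and initialises the process group when multiple
    # processes are started.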

    if not args.no_cuda and not torch.cuda.is_available():
        raise Exception("No GPU available")
    torch.backends.cudnn.benchmark = args.cudnn
    # Init model
    args.channels_in = 1 if args.input_type == 'depth' else 4
    model = Models.define_model(args.mod, args)
    define_init_weights(model, args.weight_init)
    if args.world_size > 1:
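        # Replace BatchNorm layers with SyncBatchNorm so statistics are
        # aggregated across all DDP processes during training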
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        print(model)

    # Load on gpu before passing params to optimizer
    if not args.no_cuda:
        model.cuda()
        if args.multi:
            # Keep the DDP wrapper; reassigning model = model.module here
            # would discard it and silently disable gradient synchronization
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    save_id = '{}_{}_{}_{}_{}_batch{}_pretrain{}_wrgb{}_drop{}_patience{}_num_samples{}_multi{}_submod{}'.\
              format(args.mod, args.optimizer, args.loss_criterion,
                     args.learning_rate,
                     args.input_type, 
                     args.batch_size,
                     args.pretrained, args.wrgb, args.drop, 
                     args.lr_decay_iters, args.num_samples, args.multi, args.submod)


    # INIT optimizer/scheduler/loss criterion
    optimizer = define_optim(args.optimizer, model.parameters(), args.learning_rate, args.weight_decay)
    scheduler = define_scheduler(optimizer, args)

    # Loss criteria for the local (depth) branch and the RGB branch; the same
    # criterion is used for both here, but they could be chosen differently
    criterion_local = define_loss(args.loss_criterion)
    criterion_rgb = define_loss(args.loss_criterion)

    # INIT dataset
    dataset = Datasets.define_dataset(args.dataset, args.data_path, args.input_type)
    dataset.prepare_dataset()
    train_loader, train_sampler, valid_loader, valid_selection_loader = get_loader(args, dataset)
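    # get_loader is expected to return the train loader plus its (distributed)
    # sampler, the full validation loader, and a loader for the small
    # validation selection split used for final scoring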

    # Resume training
    best_epoch = 0
    lowest_loss = np.inf
    args.save_path = os.path.join(args.save_path, save_id)
    mkdir_if_missing(args.save_path)
    log_file_name = 'log_train_start_0.txt'
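    # first_run() presumably reads first_run.txt (written at the end of every
    # epoch below) and returns the last finished epoch, enabling auto-resume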
    args.resume = first_run(args.save_path)
    if args.resume and not args.test_mode and not args.evaluate:
        path = os.path.join(args.save_path, 'checkpoint_model_epoch_{}.pth.tar'.format(int(args.resume)))
        if os.path.isfile(path):
            log_file_name = 'log_train_start_{}.txt'.format(args.resume)
            # stdout
            sys.stdout = Logger(os.path.join(args.save_path, log_file_name))
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(path)
            args.start_epoch = checkpoint['epoch']
            lowest_loss = checkpoint['loss']
            best_epoch = checkpoint['best epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            log_file_name = 'log_train_start_0.txt'
            # stdout
            sys.stdout = Logger(os.path.join(args.save_path, log_file_name))
            print("=> no checkpoint found at '{}'".format(path))

    # Only evaluate
    elif args.evaluate:
        print("Evaluate only")
        best_file_lst = glob.glob(os.path.join(args.save_path, 'model_best*'))
        if len(best_file_lst) != 0:
            best_file_name = best_file_lst[0]
            print(best_file_name)
            if os.path.isfile(best_file_name):
                sys.stdout = Logger(os.path.join(args.save_path, 'Evaluate.txt'))
                print("=> loading checkpoint '{}'".format(best_file_name))
                checkpoint = torch.load(best_file_name)
                model.load_state_dict(checkpoint['state_dict'])
            else:
                print("=> no checkpoint found at '{}'".format(best_file_name))
        else:
            print("=> no checkpoint found at due to empy list in folder {}".format(args.save_path))
        validate(valid_selection_loader, model, criterion_global, criterion_local)
        return

    # Start training from clean slate
    else:
        # Redirect stdout
        sys.stdout = Logger(os.path.join(args.save_path, log_file_name))

    # INIT MODEL
    print(40*"="+"\nArgs:{}\n".format(args)+40*"=")
    print("Init model: '{}'".format(args.mod))
    print("Number of parameters in model {} is {:.3f}M".format(args.mod.upper(), sum(tensor.numel() for tensor in model.parameters())/1e6))

    # Load ERFNet weights pretrained on Cityscapes into the global depth net
    if args.pretrained and not args.resume:
        # Unwrap a possible DDP wrapper before accessing submodules
        net = model.module if isinstance(model, nn.parallel.DistributedDataParallel) else model
        target_state = net.depthnet.state_dict()
        check = torch.load('erfnet_pretrained.pth')
        for name, val in check.items():
            # The checkpoint was saved from a (Data)Parallel model: strip the
            # 'module.' prefix to match the single-GPU parameter names
            mono_name = name[7:]
            if mono_name not in target_state:
                continue
            try:
                target_state[mono_name].copy_(val)
            except RuntimeError:
                # Skip parameters whose shapes do not match
                continue
        print('Successfully loaded pretrained model')

    # Start training
    for epoch in range(args.start_epoch, args.nepochs):
        print("\n => Start EPOCH {}".format(epoch + 1))
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        print(args.save_path)
        # Adjust learning rate
        if args.lr_policy is not None and args.lr_policy != 'plateau':
            scheduler.step()
            lr = optimizer.param_groups[0]['lr']
            print('lr is set to {}'.format(lr))
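        # Note: the 'plateau' policy is handled separately after validation,
        # where scheduler.step() receives the validation score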

        # Define container objects
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        score_train = AverageMeter()
        score_train_1 = AverageMeter()
        metric_train = Metrics(max_depth=args.max_depth, disp=args.use_disp, normal=args.normal)

        # Put the model in training mode
        model.train()

        # compute timing
        end = time.time()

        # DistributedSampler uses the epoch as its shuffle seed; set it so
        # every epoch sees a different sample ordering
        train_sampler.set_epoch(epoch)

        # Iterate over the training batches
        for i, (input, gt) in enumerate(train_loader):

            # Time dataloader
            data_time.update(time.time() - end)

            # Put inputs on gpu if possible
            if not args.no_cuda:
                input, gt = input.cuda(non_blocking=True), gt.cuda(non_blocking=True)
            prediction = model(input)
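            # Same dual-branch loss combination as in validate()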

            if 'mod' in args.mod or 'stacked' in args.mod:
                loss = criterion_local(prediction[0], gt)
                loss_rgb = criterion_rgb(prediction[1], gt)
                loss += args.wrgb*loss_rgb
                prediction = prediction[0]
            else:
                loss = criterion_local(prediction, gt)

            losses.update(loss.item(), input.size(0))
            metric_train.calculate(prediction[:, 0:1].detach(), gt.detach())
            score_train.update(metric_train.get_metric(args.metric), metric_train.num)
            score_train_1.update(metric_train.get_metric(args.metric_1), metric_train.num)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Clip gradients (useful for instabilities or mistakes in the
            # ground truth); this must run after backward() has computed the
            # gradients and before the optimizer step
            if args.clip_grad_norm != 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)

            optimizer.step()

            # Time training iteration
            batch_time.update(time.time() - end)
            end = time.time()

            # Print info
            if (i + 1) % args.print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Metric {score.val:.4f} ({score.avg:.4f})'.format(
                       epoch+1, i+1, len(train_loader), batch_time=batch_time,
                       loss=losses,
                       score=score_train))

        # Aggregate training meters across all DDP processes
        if args.world_size > 1:
            score_train.synchronize_between_processes()
            score_train_1.synchronize_between_processes()

        print("===> Average RMSE score on training set is {:.4f}".format(score_train.avg))
        print("===> Average MAE score on training set is {:.4f}".format(score_train_1.avg))
        # Evaluate model on validation set
        print("=> Start validation set")
        score_valid, score_valid_1, losses_valid = validate(valid_loader, model, criterion_rgb, criterion_local, epoch)
        print("===> Average RMSE score on validation set is {:.4f}".format(score_valid))
        print("===> Average MAE score on validation set is {:.4f}".format(score_valid_1))
        # Evaluate model on selected validation set
        if args.subset is None:
            print("=> Start selection validation set")
            score_selection, score_selection_1, losses_selection = validate(valid_selection_loader, model, criterion_rgb, criterion_local, epoch)
            total_score = score_selection
            print("===> Average RMSE score on selection set is {:.4f}".format(score_selection))
            print("===> Average MAE score on selection set is {:.4f}".format(score_selection_1))
        else:
            total_score = score_valid

        print("===> Last best score was RMSE of {:.4f} in epoch {}".format(lowest_loss,
                                                                           best_epoch))
        # Adjust lr if loss plateaued
        if args.lr_policy == 'plateau':
            scheduler.step(total_score)
            lr = optimizer.param_groups[0]['lr']
            print('lr is set to {} (plateau policy)'.format(lr))

        # Record the last completed epoch so first_run() can auto-resume
        with open(os.path.join(args.save_path, 'first_run.txt'), 'w') as f:
            f.write(str(epoch))

        # Save model
        to_save = False
        if total_score < lowest_loss:
            to_save = True
            best_epoch = epoch+1
            lowest_loss = total_score

        if is_main_process():
            # save on master
            save_checkpoint({
                'epoch': epoch + 1,
                'best epoch': best_epoch,
                'arch': args.mod,
                'state_dict': model.state_dict(),
                'loss': lowest_loss,
                'optimizer': optimizer.state_dict()}, to_save, epoch)
    # Close the TensorBoard writer (assumed to be created globally elsewhere)
    if not args.no_tb:
        writer.close()