예제 #1
0
파일: bts_main.py 프로젝트: starkgate/bts
def main_worker(gpu, ngpus_per_node, args):
    """Train a BtsModel in a single (possibly distributed) worker process.

    Sets up the process group when ``args.distributed`` is true, builds the
    model and wraps it in DistributedDataParallel or DataParallel, creates an
    AdamW optimizer (weight decay on the encoder only), optionally resumes
    from ``args.checkpoint_path``, and then trains for ``args.num_epochs``
    epochs. The learning rate decays polynomially (power 0.9) from
    ``args.learning_rate`` to ``args.end_learning_rate``; that falls back to
    10% of the start rate when it is -1.

    On the logging rank, TensorBoard summaries are written every
    ``args.log_freq`` steps. Checkpoints are written in one of two ways:
    every ``args.save_freq`` steps, or, when ``args.do_online_eval`` is set,
    whenever an online-evaluation metric reaches a new best value.

    Args:
        gpu: Index of the GPU used by this process, or None.
        ngpus_per_node: Number of GPUs per node. Used to derive the global
            rank and the per-process batch size.
        args: Parsed command-line namespace. ``args.gpu``, ``args.rank`` and
            ``args.batch_size`` are modified in place.

    Returns:
        -1 if a NaN loss is observed on the logging rank, otherwise None.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Turn the node rank into a global process rank.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # With multiprocessing distributed training, only the first process of
    # each node prints progress, writes summaries and saves checkpoints.
    # args.rank is final from this point on.
    is_logging_rank = (not args.multiprocessing_distributed
                       or args.rank % ngpus_per_node == 0)

    # Create model
    model = BtsModel(args)
    model.train()
    model.decoder.apply(weights_init_xavier)
    set_misc(model)

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("Total number of parameters: {}".format(num_params))

    num_params_update = sum(
        [np.prod(p.shape) for p in model.parameters() if p.requires_grad])
    print("Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # args.batch_size is given per node; split it across its GPUs.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("Model Initialized")

    global_step = 0
    # Best values so far for the 9 online-eval metrics: indices 0-5 are
    # "lower is better" (seeded with 1e3), indices 6-8 "higher is better"
    # (seeded with 0). best_eval_steps records the step of each best value.
    best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(9, dtype=np.int32)

    # Training parameters
    optimizer = torch.optim.AdamW([{
        'params': model.module.encoder.parameters(),
        'weight_decay': args.weight_decay
    }, {
        'params': model.module.decoder.parameters(),
        'weight_decay': 0
    }],
                                  lr=args.learning_rate,
                                  eps=args.adam_eps)

    # Set when a checkpoint path is given so the first logging / evaluation
    # right after (re)starting is skipped.
    model_just_loaded = False
    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            try:
                best_eval_measures_higher_better = checkpoint[
                    'best_eval_measures_higher_better'].cpu()
                best_eval_measures_lower_better = checkpoint[
                    'best_eval_measures_lower_better'].cpu()
                best_eval_steps = checkpoint['best_eval_steps']
            except KeyError:
                # Periodic checkpoints carry no online-eval state.
                print("Could not load values for online evaluation")

            print("Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))
        model_just_loaded = True

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args)
    # NOTE(review): built with the same arguments as the training loader;
    # confirm BtsDataLoader selects the evaluation split here.
    dataloader_eval = BtsDataLoader(args)

    # Logging
    if is_logging_rank:
        writer = SummaryWriter(os.path.join(args.log_directory,
                                            args.model_name, 'summaries'),
                               flush_secs=30)
        if args.do_online_eval:
            if args.eval_summary_directory != '':
                eval_summary_path = os.path.join(args.eval_summary_directory,
                                                 args.model_name)
            else:
                eval_summary_path = os.path.join(args.log_directory, 'eval')
            eval_summary_writer = SummaryWriter(eval_summary_path,
                                                flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)

    start_time = time.time()
    duration = 0

    num_log_images = args.batch_size
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("Initial variables' sum: {:.3f}, avg: {:.3f}".format(
        var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    # Resume in the epoch that contains the restored global_step.
    epoch = global_step // steps_per_epoch

    # Checkpoint file names are appended to this directory path.
    model_dir = args.log_directory + '/' + args.model_name

    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            mask = torch.autograd.Variable(sample_batched['mask'].cuda(
                args.gpu, non_blocking=True))
            image = torch.autograd.Variable(sample_batched['image'].cuda(
                args.gpu, non_blocking=True))
            focal = torch.autograd.Variable(sample_batched['focal'].cuda(
                args.gpu, non_blocking=True))
            depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(
                args.gpu, non_blocking=True))

            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)

            # For NYU and KITTI, the valid-pixel mask comes from a depth
            # threshold. Any other dataset uses the mask from the batch.
            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            elif args.dataset == 'kitti':
                mask = depth_gt > 1.0

            # Put the mask on the prediction's device as a boolean mask. A
            # hard-coded 'cuda:0' breaks every worker whose GPU is not 0.
            loss = silog_criterion.forward(
                depth_est, depth_gt,
                mask.to(device=depth_est.device, dtype=torch.bool))
            loss.backward()

            # Polynomial (power 0.9) decay, applied to every param group.
            current_lr = (args.learning_rate - end_learning_rate) * (
                1 - global_step / num_total_steps)**0.9 + end_learning_rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.step()

            if is_logging_rank:
                print(
                    '[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'
                    .format(epoch, step, steps_per_epoch, global_step,
                            current_lr, loss))
                if np.isnan(loss.cpu().item()):
                    print('NaN in loss occurred. Aborting training.')
                    return -1

            duration += time.time() - before_op_time
            if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
                var_sum = [
                    var.sum() for var in model.parameters()
                    if var.requires_grad
                ]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * args.log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step -
                                      1.0) * time_sofar
                if is_logging_rank:
                    print("{}".format(args.model_name))
                print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                print(
                    print_string.format(args.gpu, examples_per_sec, loss,
                                        var_sum.item(),
                                        var_sum.item() / var_cnt, time_sofar,
                                        training_time_left))

                if is_logging_rank:
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average',
                                      var_sum.item() / var_cnt, global_step)
                    # Replace near-zero (invalid) depths with a large value so
                    # the 1/depth visualisation does not blow up.
                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3,
                                           depth_gt)
                    # The last batch of an epoch may be smaller than
                    # args.batch_size, so never index past the real batch.
                    for i in range(min(num_log_images, image.size(0))):
                        writer.add_image(
                            'depth_gt/image/{}'.format(i),
                            normalize_result(1 / depth_gt[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'depth_est/image/{}'.format(i),
                            normalize_result(1 / depth_est[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'reduc1x1/image/{}'.format(i),
                            normalize_result(1 / reduc1x1[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg2x2/image/{}'.format(i),
                            normalize_result(1 / lpg2x2[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg4x4/image/{}'.format(i),
                            normalize_result(1 / lpg4x4[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg8x8/image/{}'.format(i),
                            normalize_result(1 / lpg8x8[i, :, :, :].data),
                            global_step)
                        writer.add_image('image/image/{}'.format(i),
                                         inv_normalize(image[i, :, :, :]).data,
                                         global_step)
                    writer.flush()

            # Without online evaluation, save a checkpoint every save_freq steps.
            if not args.do_online_eval and global_step and global_step % args.save_freq == 0:
                if is_logging_rank:
                    checkpoint = {
                        'global_step': global_step,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    torch.save(checkpoint,
                               model_dir + '/model-{}'.format(global_step))

            if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
                time.sleep(0.1)
                model.eval()
                # NOTE(review): eval_summary_writer is only created on the
                # logging rank. This assumes online_eval returns None on
                # every other rank — confirm.
                eval_measures = online_eval(model, dataloader_eval, gpu,
                                            ngpus_per_node)
                if eval_measures is not None:
                    for i in range(9):
                        eval_summary_writer.add_scalar(eval_metrics[i],
                                                       eval_measures[i].cpu(),
                                                       int(global_step))
                        measure = eval_measures[i]
                        is_best = False
                        if i < 6 and measure < best_eval_measures_lower_better[i]:
                            old_best = best_eval_measures_lower_better[i].item()
                            best_eval_measures_lower_better[i] = measure.item()
                            is_best = True
                        elif i >= 6 and measure > best_eval_measures_higher_better[i - 6]:
                            old_best = best_eval_measures_higher_better[
                                i - 6].item()
                            best_eval_measures_higher_better[
                                i - 6] = measure.item()
                            is_best = True
                        if is_best:
                            # Keep a single "best" checkpoint per metric: drop
                            # the previous one before writing the new one.
                            old_best_step = best_eval_steps[i]
                            old_best_name = '/model-{}-best_{}_{:.5f}'.format(
                                old_best_step, eval_metrics[i], old_best)
                            model_path = model_dir + old_best_name
                            if os.path.exists(model_path):
                                # os.remove avoids spawning a shell and copes
                                # with spaces or metacharacters in the path.
                                os.remove(model_path)
                            best_eval_steps[i] = global_step
                            model_save_name = '/model-{}-best_{}_{:.5f}'.format(
                                global_step, eval_metrics[i], measure)
                            print('New best for {}. Saving model: {}'.format(
                                eval_metrics[i], model_save_name))
                            checkpoint = {
                                'global_step': global_step,
                                'model': model.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'best_eval_measures_higher_better':
                                best_eval_measures_higher_better,
                                'best_eval_measures_lower_better':
                                best_eval_measures_lower_better,
                                'best_eval_steps': best_eval_steps
                            }
                            torch.save(checkpoint, model_dir + model_save_name)
                    eval_summary_writer.flush()
                model.train()
                block_print()
                set_misc(model)
                enable_print()
            model_just_loaded = False
            global_step += 1

        epoch += 1
예제 #2
0
def main_worker(gpu, ngpus_per_node, args):
    """Train a BtsModel in a single (possibly distributed) worker process.

    Sets up the process group when ``args.distributed`` is true and builds
    the model. It then freezes the early encoder layers selected by
    ``args.fix_first_conv_blocks`` / ``args.fix_first_conv_block`` (by
    default only the first conv layer and normalisation parameters) and
    wraps the model in DistributedDataParallel or DataParallel. Next it
    creates an AdamW optimizer (weight decay on the encoder only),
    optionally resumes from ``args.checkpoint_path``, and trains for
    ``args.num_epochs`` epochs. The learning rate decays polynomially
    (power 0.9) from ``args.learning_rate`` to ``args.end_learning_rate``;
    that falls back to 10% of the start rate when it is -1.

    Optional attributes of ``args`` (the fallback is used when absent):
        adam_eps: AdamW epsilon (default 1e-6).
        log_freq: Steps between TensorBoard summaries (default 100).
        save_freq: Steps between checkpoints (default 500).

    Args:
        gpu: Index of the GPU used by this process, or None.
        ngpus_per_node: Number of GPUs per node. Used to derive the global
            rank and the per-process batch size.
        args: Parsed command-line namespace. ``args.gpu``, ``args.rank`` and
            ``args.batch_size`` are modified in place.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Turn the node rank into a global process rank.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # With multiprocessing distributed training, only the first process of
    # each node prints progress, writes summaries and saves checkpoints.
    # args.rank is final from this point on.
    is_logging_rank = (not args.multiprocessing_distributed
                       or args.rank % ngpus_per_node == 0)

    # Create model
    model = BtsModel(args)
    model.train()
    model.decoder.apply(weights_init_xavier)
    if args.bn_no_track_stats:
        print("Disabling tracking running stats in batch norm layers")
        model.apply(bn_init_as_tf)

    # Substrings of encoder parameter names to freeze. ResNet/ResNeXt-style
    # encoders ('resne' in the name) and the other encoders use different
    # layer naming.
    if args.fix_first_conv_blocks:
        if 'resne' in args.encoder:
            fixing_layers = [
                'base_model.conv1', 'base_model.layer1.0',
                'base_model.layer1.1', '.bn'
            ]
        else:
            fixing_layers = [
                'conv0', 'denseblock1.denselayer1', 'denseblock1.denselayer2',
                'norm'
            ]
        print("Fixing first two conv blocks")
    elif args.fix_first_conv_block:
        if 'resne' in args.encoder:
            fixing_layers = ['base_model.conv1', 'base_model.layer1.0', '.bn']
        else:
            fixing_layers = ['conv0', 'denseblock1.denselayer1', 'norm']
        print("Fixing first conv block")
    else:
        if 'resne' in args.encoder:
            fixing_layers = ['base_model.conv1', '.bn']
        else:
            fixing_layers = ['conv0', 'norm']
        print("Fixing first conv layer")

    # Only parameters of the encoder child module are frozen.
    for name, child in model.named_children():
        if 'encoder' not in name:
            continue
        for name2, parameters in child.named_parameters():
            if any(x in name2 for x in fixing_layers):
                parameters.requires_grad = False

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("Total number of parameters: {}".format(num_params))

    num_params_update = sum(
        [np.prod(p.shape) for p in model.parameters() if p.requires_grad])
    print("Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # args.batch_size is given per node; split it across its GPUs.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("Model Initialized")

    global_step = 0

    # Training parameters. Epsilon used to be hard-coded to 1e-6; honour
    # args.adam_eps when the parser provides it.
    optimizer = torch.optim.AdamW([{
        'params': model.module.encoder.parameters(),
        'weight_decay': args.weight_decay
    }, {
        'params': model.module.decoder.parameters(),
        'weight_decay': 0
    }],
                                  lr=args.learning_rate,
                                  eps=getattr(args, 'adam_eps', 1e-6))

    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args)

    # Logging
    if is_logging_rank:
        writer = SummaryWriter(os.path.join(args.log_directory,
                                            args.model_name, 'summaries'),
                               flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)

    start_time = time.time()
    duration = 0
    # Previously hard-coded to 100 / 500; honour args.log_freq and
    # args.save_freq when the parser provides them.
    log_freq = getattr(args, 'log_freq', 100)
    save_freq = getattr(args, 'save_freq', 500)

    num_log_images = args.batch_size
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("Initial variables' sum: {:.3f}, avg: {:.3f}".format(
        var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch

    # Resume in the epoch that contains the restored global_step.
    epoch = global_step // steps_per_epoch

    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            image = torch.autograd.Variable(sample_batched['image'].cuda(
                args.gpu, non_blocking=True))
            focal = torch.autograd.Variable(sample_batched['focal'].cuda(
                args.gpu, non_blocking=True))
            depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(
                args.gpu, non_blocking=True))

            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)

            # Valid pixels are those whose ground-truth depth exceeds a
            # dataset-specific threshold.
            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            else:
                mask = depth_gt > 1.0

            loss = silog_criterion.forward(depth_est, depth_gt,
                                           mask.to(torch.bool))
            loss.backward()

            # Polynomial (power 0.9) decay, applied to every param group.
            current_lr = (args.learning_rate - end_learning_rate) * (
                1 - global_step / num_total_steps)**0.9 + end_learning_rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.step()

            if is_logging_rank:
                print(
                    '[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'
                    .format(epoch, step, steps_per_epoch, global_step,
                            current_lr, loss))

            duration += time.time() - before_op_time
            if global_step and global_step % log_freq == 0:
                var_sum = [
                    var.sum() for var in model.parameters()
                    if var.requires_grad
                ]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step -
                                      1.0) * time_sofar
                if is_logging_rank:
                    print("{}".format(args.model_name))
                print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                print(
                    print_string.format(args.gpu, examples_per_sec, loss,
                                        var_sum.item(),
                                        var_sum.item() / var_cnt, time_sofar,
                                        training_time_left))

                if is_logging_rank:
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average',
                                      var_sum.item() / var_cnt, global_step)
                    # Replace near-zero (invalid) depths with a large value so
                    # the 1/depth visualisation does not blow up.
                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3,
                                           depth_gt)
                    # The last batch of an epoch may be smaller than
                    # args.batch_size, so never index past the real batch.
                    for i in range(min(num_log_images, image.size(0))):
                        writer.add_image(
                            'depth_gt/image/{}'.format(i),
                            normalize_result(1 / depth_gt[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'depth_est/image/{}'.format(i),
                            normalize_result(1 / depth_est[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'reduc1x1/image/{}'.format(i),
                            normalize_result(1 / reduc1x1[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg2x2/image/{}'.format(i),
                            normalize_result(1 / lpg2x2[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg4x4/image/{}'.format(i),
                            normalize_result(1 / lpg4x4[i, :, :, :].data),
                            global_step)
                        writer.add_image(
                            'lpg8x8/image/{}'.format(i),
                            normalize_result(1 / lpg8x8[i, :, :, :].data),
                            global_step)
                        writer.add_image('image/image/{}'.format(i),
                                         inv_normalize(image[i, :, :, :]).data,
                                         global_step)
                    writer.flush()

            if global_step and global_step % save_freq == 0:
                if is_logging_rank:
                    checkpoint = {
                        'global_step': global_step,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    torch.save(
                        checkpoint, args.log_directory + '/' +
                        args.model_name + '/model-{}'.format(global_step))

            global_step += 1

        epoch += 1