Example #1
def main():
    global args, best_error, n_iter
    args = parser.parse_args()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=4,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True,
                                               drop_last=True)
    print(len(train_loader))
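Note: `custom_transforms` above is a project-specific module in the SfMLearner family, not torchvision; its transforms act on a list of numpy images (plus camera intrinsics) rather than a single PIL image. A minimal sketch of the two pieces every such pipeline needs, assuming that calling convention (the real module may differ):

import torch

class Compose(object):
    # applies each transform in turn to (list of HxWxC numpy images, 3x3 intrinsics)
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, images, intrinsics):
        for t in self.transforms:
            images, intrinsics = t(images, intrinsics)
        return images, intrinsics

class ArrayToTensor(object):
    # converts each HxWxC numpy array to a CxHxW float tensor scaled to [0, 1]
    def __call__(self, images, intrinsics):
        tensors = [torch.from_numpy(im.transpose(2, 0, 1)).float() / 255
                   for im in images]
        return tensors, intrinsics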
Example #2
def get_train_transforms(normalize):
  train_transforms = []
  train_transforms.append(transforms.Scale(160))
  train_transforms.append(transforms.RandomHorizontalFlip())
  train_transforms.append(transforms.RandomColor(0.15))
  train_transforms.append(transforms.RandomRotate(15))
  train_transforms.append(transforms.RandomSizedCrop(128))
  train_transforms.append(transforms.ToTensor())
  train_transforms.append(normalize)
  train_transforms = transforms.Compose(train_transforms)
  return train_transforms
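`Scale`, `RandomColor`, `RandomRotate`, and `RandomSizedCrop` are not torchvision transforms here; `transforms` is a custom module. A hypothetical PIL-based `RandomRotate(15)` could look like this sketch (not the project's actual implementation):

import random
from PIL import Image

class RandomRotate(object):
    # rotates a PIL image by a random angle drawn from [-max_degree, max_degree]
    def __init__(self, max_degree):
        self.max_degree = max_degree

    def __call__(self, img):
        degree = random.uniform(-self.max_degree, self.max_degree)
        return img.rotate(degree, resample=Image.BILINEAR)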
Example #3
		return sample


if __name__ == '__main__':
	import os
	import numpy as np
	print(os.getcwd())
	# from dataloaders import custom_transforms as tr
	import custom_transforms as tr
	from torch.utils.data import DataLoader
	from torchvision import transforms
	import matplotlib.pyplot as plt

	composed_transforms_tr = transforms.Compose([
		tr.RandomHorizontalFlip(),
		tr.RandomScale((0.5, 0.75)),
		tr.RandomCrop((512, 1024)),
		tr.RandomRotate(5),
		tr.ToTensor()])

	msrab_train = MSRAB(split='train', transform=composed_transforms_tr)

	dataloader = DataLoader(msrab_train, batch_size=2, shuffle=True, num_workers=2)

	for ii, sample in enumerate(dataloader):
		print(ii, sample["image"].size(), sample["label"].size(), type(sample["image"]), type(sample["label"]))
		for jj in range(sample["image"].size()[0]):
			img = sample['image'].numpy()
			gt = sample['label'].numpy()
			tmp = np.array(gt[jj]*255.0).astype(np.uint8)
			tmp = np.squeeze(tmp, axis=0)
			tmp = np.expand_dims(tmp, axis=2)
			segmap = np.concatenate((tmp,tmp,tmp), axis=2)
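The loop above builds `segmap` but the excerpt stops before using it; since matplotlib is imported, a plausible continuation (an assumption, the original is truncated) displays each image/mask pair inside the `jj` loop. The image values may need rescaling to [0, 1] for display:

img_tmp = np.transpose(img[jj], axes=[1, 2, 0])  # CHW -> HWC for imshow
plt.figure()
plt.subplot(121)
plt.imshow(img_tmp)
plt.subplot(122)
plt.imshow(segmap)
plt.show()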
Example #4
        return _img, _target

    def __str__(self):
        return 'ISPRS(split=' + str(self.split) + ')'


if __name__ == '__main__':
    import custom_transforms as tr
    from utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt
    import numpy as np

    composed_transforms_tr = transforms.Compose([tr.RandomHorizontalFlip(),
                                                 tr.RandomCrop(512),
                                                 tr.RandomRotate(15),
                                                 tr.ToTensor()]
                                                )

    isprs_train = ISPRSSegmentation(split='train', transform=composed_transforms_tr)

    dataloader = DataLoader(isprs_train, batch_size=5, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            tmp = np.squeeze(tmp, axis=0)
            segmap = decode_segmap(tmp, dataset='ISPRS')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)
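`decode_segmap` comes from the project's `utils`; a minimal sketch of such a function, using the standard ISPRS 6-class palette (the project's actual mapping may differ):

import numpy as np

def decode_segmap(label_mask, dataset='ISPRS'):
    # map each class index in an (H, W) integer mask to an RGB color
    palette = np.array([[255, 255, 255],   # impervious surfaces
                        [0, 0, 255],       # building
                        [0, 255, 255],     # low vegetation
                        [0, 255, 0],       # tree
                        [255, 255, 0],     # car
                        [255, 0, 0]],      # clutter
                       dtype=np.uint8)
    return palette[label_mask]  # -> (H, W, 3) uint8 image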
Example #5
def main(args):
  # parse args
  best_acc1 = 0.0

  if args.gpu >= 0:
    print("Use GPU: {}".format(args.gpu))
  else:
    print('You are using CPU for computing!',
          'Yet we assume you are using a GPU.',
          'You will NOT be able to switch between CPU and GPU training!')

  # set up the model + loss
  if args.use_custom_conv:
    print("Using custom convolutions in the network")
    model = default_model(conv_op=CustomConv2d, num_classes=100)
  else:
    model = default_model(num_classes=100)
  model_arch = "simplenet"
  # model_arch = "simplenet_batchnorm2d"
  # model_arch = "resnet18"
  # model_arch = "resnet34"
  criterion = nn.CrossEntropyLoss()
  # put everything on the GPU
  if args.gpu >= 0:
    model = model.cuda(args.gpu)
    criterion = criterion.cuda(args.gpu)

  # setup the optimizer
  optimizer = torch.optim.SGD(model.parameters(), args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)

  # resume from a checkpoint?
  if args.resume:
    if os.path.isfile(args.resume):
      print("=> loading checkpoint '{}'".format(args.resume))
      checkpoint = torch.load(args.resume)
      args.start_epoch = checkpoint['epoch']
      best_acc1 = checkpoint['best_acc1']
      model.load_state_dict(checkpoint['state_dict'])
      if args.gpu < 0:
        model = model.cpu()
      else:
        model = model.cuda(args.gpu)
      # only load the optimizer if necessary
      if (not args.evaluate) and (not args.attack):
        optimizer.load_state_dict(checkpoint['optimizer'])
      print("=> loaded checkpoint '{}' (epoch {}, acc1 {})"
          .format(args.resume, checkpoint['epoch'], best_acc1))
    else:
      print("=> no checkpoint found at '{}'".format(args.resume))

  # set up transforms for data augmentation
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
  train_transforms = []
  train_transforms.append(transforms.Scale(160))
  train_transforms.append(transforms.RandomHorizontalFlip())
  train_transforms.append(transforms.RandomColor(0.15))
  train_transforms.append(transforms.RandomRotate(15))
  train_transforms.append(transforms.RandomSizedCrop(128))
  train_transforms.append(transforms.ToTensor())
  train_transforms.append(normalize)
  train_transforms = transforms.Compose(train_transforms)
  # val transforms
  val_transforms=[]
  val_transforms.append(transforms.Scale(160, interpolations=None))
  val_transforms.append(transforms.ToTensor())
  val_transforms.append(normalize)
  val_transforms = transforms.Compose(val_transforms)
  if (not args.evaluate) and (not args.attack):
    print("Training time data augmentations:")
    print(train_transforms)

  # setup dataset and dataloader
  train_dataset = MiniPlacesLoader(args.data_folder,
                  split='train', transforms=train_transforms)
  val_dataset = MiniPlacesLoader(args.data_folder,
                  split='val', transforms=val_transforms)

  train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=True)
  val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=100, shuffle=False,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=False)

  # sanity check: evaluate and attack are mutually exclusive
  if args.evaluate and args.attack:
    print("Can't set evaluate and attack to True at the same time!")
    return

  # set up visualizer
  if args.vis:
    visualizer = default_attention(criterion)
  else:
    visualizer = None

  # evaluation
  if args.resume and args.evaluate:
    print("Testing the model ...")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args, visualizer=visualizer)
    return

  # attack
  if args.resume and args.attack:
    print("Generating adversarial samples for the model ..")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args,
             attacker=default_attack(criterion),
             visualizer=visualizer)
    return

  # enable cudnn benchmark
  cudnn.enabled = True
  cudnn.benchmark = True

  # warmup the training
  if (args.start_epoch == 0) and (args.warmup_epochs > 0):
    print("Warmup the training ...")
    for epoch in range(0, args.warmup_epochs):
      train(train_loader, model, criterion, optimizer, epoch, "warmup", args)

  # start the training
  print("Training the model ...")
  for epoch in range(args.start_epoch, args.epochs):
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, "train", args)

    # evaluate on validation set
    acc1 = validate(val_loader, model, epoch, args)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)
    save_checkpoint({
      'epoch': epoch + 1,
      'model_arch': model_arch,
      'state_dict': model.state_dict(),
      'best_acc1': best_acc1,
      'optimizer' : optimizer.state_dict(),
    }, is_best)
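`save_checkpoint` is defined elsewhere in this project; a common minimal version (an assumption, modeled on the PyTorch ImageNet example) is:

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # always save the latest state; copy it aside when it is the best so far
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')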
Example #6
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])
    training_writer = SummaryWriter(args.save_path)

    intrinsics = np.array(
        [542.822841, 0, 315.593520, 0, 542.576870, 237.756098, 0, 0,
         1]).astype(np.float32).reshape((3, 3))
    train_set = SequenceFolder(root=args.dataset_dir,
                               intrinsics=intrinsics,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    print("=> creating model")
    mask_net = MaskNet6.MaskNet6().cuda()
    flow_net = back2future.Model(nlevels=args.nlevels).cuda()
    pose_net = PoseNetB6.PoseNetB6().cuda()

    if args.pretrained_mask:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()
    mask_net = torch.nn.DataParallel(mask_net)

    if args.pretrained_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()

    if args.pretrained_flow:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    print('=> setting adam solver')
    parameters = chain(mask_net.parameters(), pose_net.parameters(),
                       flow_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # training; track the lowest train loss as the best model
    best_error = float("inf")
    for epoch in tqdm(range(args.epochs)):
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.fix_masknet:
            for fparams in mask_net.parameters():
                fparams.requires_grad = False

        if args.fix_posenet:
            for fparams in pose_net.parameters():
                fparams.requires_grad = False

        train_loss = train(train_loader, mask_net, pose_net, flow_net,
                           optimizer, args.epoch_size, training_writer)

        is_best = train_loss < best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': mask_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': flow_net.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': optimizer.state_dict()
        }, is_best)
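The three `fix_*` blocks above repeat the same freezing pattern; a small helper (not in the original code) makes the intent explicit and also allows un-freezing:

def set_requires_grad(module, requires_grad):
    # toggle gradient computation for every parameter of a network
    for p in module.parameters():
        p.requires_grad = requires_grad

# usage inside the epoch loop, e.g.:
# set_requires_grad(flow_net, not args.fix_flownet)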
Example #7
def main():
    global args, best_error, n_iter
    n_iter = 0
    best_error = 0

    args = parser.parse_args()
    '''
    args = ['--name', 'deemo', '--FCCMnet', 'PatchWiseNetwork', 
                        '--dataset_dir', '/notebooks/FCCM/pre-process/pre-process', '--label_dir', '/notebooks/FCCM', 
                        '--batch-size', '4', '--epochs','100', '--lr', '1e-4' ]
    '''
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    train_writer = SummaryWriter(args.save_path)
    torch.manual_seed(args.seed)



    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor()
    ])

    train_set = Generate_train_set(
        root=args.dataset_dir,
        label_root=args.label_dir,
        transform=train_transform,
        seed=args.seed,
        train=True
    )

    val_set = Generate_val_set(
        root=args.dataset_dir,
        label_root=args.label_dir,
        seed=args.seed
    )

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} val scenes'.format(len(val_set), len(val_set.scenes)))
    
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    print("=> creating model")
    if args.FCCMnet == 'VGG':
        FCCM_net = models_inpytorch.vgg16(num_classes=19)
        FCCM_net.features[0] = nn.Conv2d(1, 64, kernel_size=3, padding=1)
    elif args.FCCMnet == 'ResNet':
        FCCM_net = models.resnet18(num_classes=19)
    elif args.FCCMnet == 'CatNet':
        FCCM_net = models.catnet18(num_classes=19)
    elif args.FCCMnet == 'CatNet_FCCM':
        FCCM_net = models.catnet1(num_classes=19)
    elif args.FCCMnet == 'ResNet_FCCM':
        FCCM_net = models.resnet1(num_classes=19)
    elif args.FCCMnet == 'ImageWise':
        FCCM_net = models.ImageWiseNetwork()
    elif args.FCCMnet == 'PatchWise':
        FCCM_net = models.PatchWiseNetwork()
    elif args.FCCMnet == 'Baseline':
        FCCM_net = models.Baseline()

    FCCM_net = FCCM_net.cuda()

    if args.pretrained_model:
        print("=> using pre-trained weights for net")
        weights = torch.load(args.pretrained_model)
        FCCM_net.load_state_dict(weights['state_dict'])


    cudnn.benchmark = True
    FCCM_net = torch.nn.DataParallel(FCCM_net)

    print('=> setting adam solver')

    parameters = FCCM_net.parameters()
    optimizer = torch.optim.Adam(parameters, args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)


    best_error = float("inf")
    FCCM_net.train()
    for epoch in tqdm(range(args.epochs)):
        loss = train(train_loader, FCCM_net, optimizer, args.epoch_size, train_writer)
        validation(val_loader, FCCM_net, epoch, train_writer)

        is_best = loss <= best_error
        best_error = min(best_error, loss)
        save_checkpoint(
                args.save_path, {
                    'epoch': epoch + 1,
                    'state_dict': FCCM_net.state_dict()
                }, {
                    'epoch': epoch + 1,
                    'state_dict': optimizer.state_dict()
                }, is_best)
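The `if`/`elif` chain that picks the network can also be written as a lookup table; a sketch with the same constructors (VGG stays separate because it mutates the model after construction):

MODEL_FACTORY = {
    'ResNet': lambda: models.resnet18(num_classes=19),
    'CatNet': lambda: models.catnet18(num_classes=19),
    'CatNet_FCCM': lambda: models.catnet1(num_classes=19),
    'ResNet_FCCM': lambda: models.resnet1(num_classes=19),
    'ImageWise': models.ImageWiseNetwork,
    'PatchWise': models.PatchWiseNetwork,
    'Baseline': models.Baseline,
}
FCCM_net = MODEL_FACTORY[args.FCCMnet]()  # raises KeyError on unknown names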
Example #8
File: main2.py Project: cudnn/cc-1
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # sentinel for choosing the best model

    #mkdir
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    # make writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    if args.fix_flownet:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])
    else:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomRotate(),
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    # train set: only one train loader is built
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
        train_set = SequenceFolder(  #mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  #5
            target_transform=None)
    elif args.dataset_format == 'sequential_with_gt':  # with all possible gt
        from datasets.sequence_mc import SequenceFolder
        train_set = SequenceFolder(  # mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  # 5
            target_transform=None)
    else:
        return

    if args.DEBUG:
        # truncating samples is enough; assigning to __len__ on the instance has no effect on len()
        train_set.samples = train_set.samples[:32]
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

# val sets and loaders are built one by one

# if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.val_without_gt:
        from datasets.sequence_folders2 import SequenceFolder  # just one more directory level than sequence_folders
        val_set_without_gt = SequenceFolder(  # images only
            args.data_dir,
            transform=valid_transform,
            seed=None,
            train=False,
            sequence_length=args.sequence_length,
            target_transform=None)
        val_loader = torch.utils.data.DataLoader(val_set_without_gt,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 drop_last=True)

    if args.val_with_depth_gt:
        from datasets.validation_folders2 import ValidationSet

        val_set_with_depth_gt = ValidationSet(args.data_dir,
                                              transform=valid_transform)

        val_loader_depth = torch.utils.data.DataLoader(
            val_set_with_depth_gt,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    if args.val_with_flow_gt:  # not available yet
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=args.sequence_length,
                                      transform=valid_flow_transform)
        val_flow_loader = torch.utils.data.DataLoader(
            val_flow_set,
            batch_size=1,
            # batch size is 1 since images in kitti have different sizes
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    if args.val_without_gt:
        print('{} samples found in {} valid scenes'.format(
            len(val_set_without_gt), len(val_set_without_gt.scenes)))

#1 create model
    print("=> creating model")
    #1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  #args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    #1.2 pose_net
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=args.sequence_length -
                                             1).cuda()

    #1.3.flow_net
    if args.flownet == 'SpyNet':
        flow_net = getattr(models,
                           args.flownet)(nlevels=args.nlevels,
                                         pre_normalization=normalize).cuda()
    elif args.flownet in ('FlowNetC6', 'FlowNetS', 'Back2Future'):
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # 1.4 mask_net
    mask_net = getattr(models,
                       args.masknet)(nb_ref_imgs=args.sequence_length - 1,
                                     output_exp=True).cuda()

    #2 load pretrained weights
    #2.1 pose
    if args.pretrained_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()

    if args.pretrained_mask:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()

    # import ipdb; ipdb.set_trace()
    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        posenet_weights = torch.load(args.save_path /
                                     'posenet_checkpoint.pth.tar')
        masknet_weights = torch.load(args.save_path /
                                     'masknet_checkpoint.pth.tar')
        flownet_weights = torch.load(args.save_path /
                                     'flownet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])
        pose_net.load_state_dict(posenet_weights['state_dict'])
        flow_net.load_state_dict(flownet_weights['state_dict'])
        mask_net.load_state_dict(masknet_weights['state_dict'])

    # import ipdb; ipdb.set_trace()
    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)
    mask_net = torch.nn.DataParallel(mask_net)
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_net.parameters(),
                       mask_net.parameters(), flow_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_cam_loss', 'photo_flow_loss',
            'explainability_loss', 'smooth_loss'
        ])

    #
    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.epoch_bar.start()
    else:
        logger = None

# pre-evaluate before the training loop

    if args.pretrained_disp or args.evaluate:
        if args.log_terminal:
            logger.reset_valid_bar()
        if args.val_without_gt:
            pass
            #val_loss = validate_without_gt(val_loader,disp_net,pose_net,mask_net,flow_net,epoch=0, logger=logger, tb_writer=tb_writer,nb_writers=3,global_vars_dict = global_vars_dict)
            #val_loss =0

        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=0,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)


#3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated before entering the loop
        #3.1 decide which of the four sub-networks to train
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.fix_masknet:
            for fparams in mask_net.parameters():
                fparams.requires_grad = False

        if args.fix_posenet:
            for fparams in pose_net.parameters():
                fparams.requires_grad = False

        if args.fix_dispnet:
            for fparams in disp_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()
        #validation data
        flow_error_names = ['no']
        flow_errors = [0]
        errors = [0]
        error_names = ['no error names depth']
        print('\nepoch [{}/{}]\n'.format(epoch + 1, args.epochs))
        #3.2 train for one epoch---------
        #train_loss=0
        train_loss = train_gt(train_loader, disp_net, pose_net, mask_net,
                              flow_net, optimizer, logger, tb_writer,
                              global_vars_dict)

        #3.3 evaluate on validation set-----

        if args.val_without_gt:
            val_loss = validate_without_gt(val_loader,
                                           disp_net,
                                           pose_net,
                                           mask_net,
                                           flow_net,
                                           epoch=0,
                                           logger=logger,
                                           tb_writer=tb_writer,
                                           nb_writers=3,
                                           global_vars_dict=global_vars_dict)

        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=epoch,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)

        if args.val_with_flow_gt:
            pass
            #flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, disp_net, pose_net, mask_net, flow_net, epoch, logger, tb_writer)

            #for error, name in zip(flow_errors, flow_error_names):
            #    training_writer.add_scalar(name, error, epoch)

        #----------------------

        #3.4 Up to you to choose the most relevant error to measure your model's performance; careful, some measures are to be maximized (such as a1, a2, a3)

        if not args.fix_posenet:
            decisive_error = 0  # flow_errors[-2]    # epe_rigid_with_gt_mask
        elif not args.fix_dispnet:
            decisive_error = 0  # errors[0]      #depth abs_diff
        elif not args.fix_flownet:
            decisive_error = 0  # flow_errors[-1]    #epe_non_rigid_with_gt_mask
        elif not args.fix_masknet:
            decisive_error = 0  #flow_errors[3]     # percent outliers

        #3.5 log
        if args.log_terminal:
            logger.train_writer.write(
                ' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()
        #epoch data logged on tensorboard
        #train loss
        tb_writer.add_scalar('epoch/train_loss', train_loss, epoch)
        #val_without_gt loss
        if args.val_without_gt:
            tb_writer.add_scalar('epoch/val_loss', val_loss, epoch)

        if args.val_with_depth_gt:
            #val with depth gt
            for error, name in zip(depth_errors, depth_error_names):
                tb_writer.add_scalar('epoch/' + name, error, epoch)

        #3.6 remember the lowest error and save checkpoint

        if best_error < 0:
            best_error = train_loss

        is_best = train_loss <= best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': mask_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': flow_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    if args.log_terminal:
        logger.epoch_bar.finish()
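Example #8 bootstraps `best_error` with the sentinel -1 and replaces it on the first epoch; initializing with infinity (as Example #7 does) removes the special case. A self-contained illustration with dummy losses:

best_error = float('inf')
for epoch, train_loss in enumerate([0.9, 0.7, 0.8], start=1):  # dummy losses
    is_best = train_loss <= best_error
    best_error = min(best_error, train_loss)
    print(epoch, train_loss, is_best)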