def train_dino(args):
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = DataAugmentationDINO(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
    )
    # dataset = datasets.ImageFolder(args.data_path, transform=transform)
    from sen12ms import get_transform
    dataset = AllSen12MSDataset(args.data_path, "train", transform=transform,
                                tansform_coord=None, classes=None, seasons=None,
                                split_by_region=True, download=False)
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building student and teacher networks ... ============
    # if the network is a vision transformer (i.e. deit_tiny, deit_small, vit_base)
    if args.arch in vits.__dict__.keys():
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=0.1,  # stochastic depth
        )
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        embed_dim = student.embed_dim
        # adapt the patch embedding to the 13-channel (multispectral) Sen12MS inputs
        student = utils.replace_input_layer(student, inchannels=13)
        teacher = utils.replace_input_layer(teacher, inchannels=13)
    # otherwise, we check if the architecture is in torchvision models
    elif args.arch in torchvision_models.__dict__.keys():
        student = torchvision_models.__dict__[args.arch]()
        teacher = torchvision_models.__dict__[args.arch]()
        embed_dim = student.fc.weight.shape[1]
    else:
        print(f"Unknown architecture: {args.arch}")

    # multi-crop wrapper handles forward with inputs of different resolutions
    student = utils.MultiCropWrapper(student, DINOHead(
        embed_dim,
        args.out_dim,
        use_bn=args.use_bn_in_head,
        norm_last_layer=args.norm_last_layer,
    ))
    teacher = utils.MultiCropWrapper(
        teacher,
        DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)

        # we need DDP wrapper to have synchro batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing loss ... ============
    dino_loss = DINOLoss(
        args.out_dim,
        args.local_crops_number + 2,  # total number of crops = 2 global crops + local_crops_number
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
    ).cuda()

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = utils.LARS(params_groups)  # to use with convnet and large batches
    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = utils.cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear scaling rule
        args.min_lr,
        args.epochs, len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs, len(data_loader),
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))
    print(f"Loss, optimizer and schedulers ready.")

    # ============ optionally resume training ... ============
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting DINO training !")
    for epoch in range(start_epoch, args.epochs):
        data_loader.sampler.set_epoch(epoch)

        # ============ training one epoch of DINO ... ============
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
                                      data_loader, optimizer, lr_schedule, wd_schedule,
                                      momentum_schedule, epoch, fp16_scaler, args)

        # ============ writing logs ... ============
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'args': args,
            'dino_loss': dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
        utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(save_dict,
                                 os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            'epoch': epoch,
        }
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
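
# For reference, a minimal sketch of how the three cosine schedules built above are
# typically consumed inside train_one_epoch (which is not shown in this excerpt),
# following the upstream DINO recipe: the learning rate and weight decay are set
# per iteration, and the teacher is updated as an EMA of the student. The function
# name and the iteration index `it` are illustrative assumptions, not this repo's code.
def apply_schedules_sketch(student, teacher_without_ddp, optimizer,
                           lr_schedule, wd_schedule, momentum_schedule, it):
    # per-iteration learning rate / weight decay (only the first param group is regularized)
    for i, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = lr_schedule[it]
        if i == 0:
            param_group["weight_decay"] = wd_schedule[it]
    # EMA update of the teacher: no gradients ever flow through it
    with torch.no_grad():
        m = momentum_schedule[it]
        for param_q, param_k in zip(student.module.parameters(),
                                    teacher_without_ddp.parameters()):
            param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
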
def __init__(self, student, teacher, length, val_loader, embed_dim, args):
    super().__init__()
    # self.save_hyperparameters()
    self.ratio = args.ratio
    # teacher and student start with the same weights
    teacher.load_state_dict(student.state_dict())
    self.student = NetWrapper(student, embed_dim, args, True)
    self.teacher = NetWrapper(teacher, embed_dim, args)
    self.teacher.projector.load_state_dict(
        self.student.projector[-1].state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in self.teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(self.student)
    if args.optimizer == "adamw":
        self.optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        self.optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        self.optimizer = utils.LARS(params_groups)  # to use with convnet and large batches

    # number of scheduler steps per epoch after gradient accumulation and multi-GPU splitting
    length = math.ceil(length / (args.accumulate * torch.cuda.device_count()))

    # ============ init schedulers ... ============
    self.lr_schedule = utils.cosine_scheduler(
        args.lr * (args.accumulate * args.batch_size_per_gpu
                   * torch.cuda.device_count()) / 256.,  # linear scaling rule
        args.min_lr * (args.accumulate * args.batch_size_per_gpu
                       * torch.cuda.device_count()) / 256.,
        args.epochs, length,
        warmup_epochs=args.warmup_epochs,
    )
    self.wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs, length,
    )
    # print(length)
    # momentum parameter is increased to 1. during training with a cosine schedule
    self.momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                                    args.epochs, length)
    print(f"Loss, optimizer and schedulers ready.")

    self.val_loader = val_loader
    self.aug1 = torch.nn.Sequential(
        RandomApply(T.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.3),
        T.RandomGrayscale(p=0.2),
        T.RandomHorizontalFlip(),
        RandomApply(T.GaussianBlur((3, 3), (1.0, 2.0)), p=0.2),
        T.RandomResizedCrop((image_size, image_size)),
        T.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                    std=torch.tensor([0.229, 0.224, 0.225])),
    )
    self.aug2 = self.aug1
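
# A minimal sketch of how the paired augmentations defined in __init__ are typically
# consumed in a BYOL/DINO-style step: the same batch is transformed twice, one view per
# branch; since aug2 aliases aug1, both views are drawn from the same distribution.
# The helper name `_two_views_sketch` and the `images` argument are illustrative
# assumptions, not part of the original module.
def _two_views_sketch(self, images):
    return self.aug1(images), self.aug2(images)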