# NOTE: imports assumed by these training entry points; module paths follow a
# typical SimSiam-style repo layout and may need adjusting for this project.
import os
from datetime import datetime

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision import datasets
from tqdm import tqdm

from augmentations import get_aug
from datasets import get_dataset
from models import get_model, get_backbone
from optimizers import get_optimizer, LR_Scheduler
from tools import Logger, knn_monitor
from linear_eval import main as linear_eval


def main(device, args):
    train_directory = '../data/train'
    image_name_file = '../data/original.csv'
    val_directory = '../data/train'

    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset('random', train_directory, image_name_file,
                            transform=get_aug(train=True, **args.aug_kwargs),
                            train=True, **args.dataset_kwargs),
        # dataset=datasets.ImageFolder(root=train_directory, transform=get_aug(train=True, **args.aug_kwargs)),
        shuffle=True,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )
    memory_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(
            root=val_directory,
            transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(
            root=val_directory,
            transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )

    # Define model (DataParallel for multi-GPU) and the AMP gradient scaler
    model = get_model(args.model).to(device)
    model = torch.nn.DataParallel(model)
    scaler = torch.cuda.amp.GradScaler()

    # Define optimizer; rates follow the linear scaling rule lr = base_lr * batch_size / 256
    optimizer = get_optimizer(
        args.train.optimizer.name, model,
        lr=args.train.base_lr * args.train.batch_size / 256,
        momentum=args.train.optimizer.momentum,
        weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs, args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs, args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 (predictor)
    )

    RESUME = False
    start_epoch = 0
    if RESUME:
        model = get_backbone(args.model.backbone)
        classifier = nn.Linear(in_features=model.output_dim, out_features=9, bias=True).to(device)
        assert args.eval_from is not None
        save_dict = torch.load(args.eval_from, map_location='cpu')
        msg = model.load_state_dict(
            {k[9:]: v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')},
            strict=True)

        path_checkpoint = "./checkpoint/simsiam-TCGA-0218-nearby_0221134812.pth"  # checkpoint path
        checkpoint = torch.load(path_checkpoint)            # load the checkpoint
        model.load_state_dict(checkpoint['net'])            # restore the learnable parameters
        optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
        start_epoch = checkpoint['epoch']                   # resume from the saved epoch

    logger = Logger(tensorboard=args.logger.tensorboard,
                    matplotlib=args.logger.matplotlib, log_dir=args.log_dir)
    accuracy = 0

    # Start training
    global_progress = tqdm(range(start_epoch, args.train.stop_at_epoch), desc='Training')
    for epoch in global_progress:
        model.train()
        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}',
                              disable=args.hide_progress)
        for idx, (images1, images2, images3, labels) in enumerate(local_progress):
            model.zero_grad()
            with torch.cuda.amp.autocast():
                data_dict = model.forward(images1.to(device, non_blocking=True),
                                          images2.to(device, non_blocking=True),
                                          images3.to(device, non_blocking=True))
                loss = data_dict['loss'].mean()  # ddp
            # AMP: scale the loss before backward; step() unscales before updating
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})
            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
            accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, device,
                                   k=min(args.train.knn_k, len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

        if (epoch % args.train.save_interval) == 0:
            # Periodic checkpoint: 'net'/'optimizer'/'epoch' are the keys the
            # RESUME branch above loads; 'state_dict' is what linear_eval expects
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
                'net': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, './checkpoint/exp_0223_triple_400_proj3/ckpt_best_%s.pth' % str(epoch))

    # Save the final checkpoint
    model_path = os.path.join(args.ckpt_dir,
                              f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth")
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.module.state_dict()
    }, model_path)
    print(f"Model saved to {model_path}")
    with open(os.path.join(args.log_dir, "checkpoint_path.txt"), 'w+') as f:
        f.write(f'{model_path}')

    if args.eval is not False:
        args.eval_from = model_path
        linear_eval(args)
def main(device, args):
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=True, **args.aug_kwargs),
                            train=True, **args.dataset_kwargs),
        shuffle=True,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs)
    memory_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
                            train=True, **args.dataset_kwargs),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
                            train=False, **args.dataset_kwargs),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs)

    # Define model
    model = get_model(args.model).to(device)
    model = torch.nn.DataParallel(model)

    # Define optimizer (linear scaling rule: lr = base_lr * batch_size / 256)
    optimizer = get_optimizer(
        args.train.optimizer.name, model,
        lr=args.train.base_lr * args.train.batch_size / 256,
        momentum=args.train.optimizer.momentum,
        weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs, args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs, args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 (predictor)
    )

    logger = Logger(tensorboard=args.logger.tensorboard,
                    matplotlib=args.logger.matplotlib, log_dir=args.log_dir)
    accuracy = 0

    # Start training
    print("Training model: {}".format(model))
    print("Will run up to {} epochs of training".format(args.train.stop_at_epoch))
    global_progress = tqdm(range(0, args.train.stop_at_epoch), desc='Training')
    for epoch in global_progress:
        model.train()
        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}',
                              disable=args.hide_progress)
        # note: the dataset also returns labels; they are unused during pretraining
        for idx, ((images1, images2), labels) in enumerate(local_progress):
            model.zero_grad()
            data_dict = model.forward(images1.to(device, non_blocking=True),
                                      images2.to(device, non_blocking=True))
            loss = data_dict['loss'].mean()  # ddp
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})
            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        # knn_monitor currently works only on CUDA-enabled devices;
        # see the MNIST yaml for an example config
        if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
            accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, device,
                                   k=min(args.train.knn_k, len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

    # Save checkpoint
    model_path = os.path.join(args.ckpt_dir,
                              f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth")
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.module.state_dict()
    }, model_path)
    print(f"Model saved to {model_path}")
    with open(os.path.join(args.log_dir, "checkpoint_path.txt"), 'w+') as f:
        f.write(f'{model_path}')

    if args.eval is not False:
        args.eval_from = model_path
        linear_eval(args)
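
# Illustration (an addition, not from the original source): the schedule that
# LR_Scheduler is configured to produce above -- linear warmup to the base rate,
# then cosine decay to the final rate, with every rate pre-scaled by
# batch_size / 256 (the linear scaling rule). Standalone sketch only; the real
# LR_Scheduler implementation may differ in details.
import math

def lr_at_step(step, warmup_steps, total_steps, warmup_lr, base_lr, final_lr):
    """Learning rate at a given optimizer step (all rates already pre-scaled)."""
    if step < warmup_steps:
        # Linear warmup from warmup_lr up to base_lr
        return warmup_lr + (base_lr - warmup_lr) * step / warmup_steps
    # Cosine decay from base_lr down to final_lr over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * progress))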
def main(gpu, args):
    rank = args.nr * args.gpus + gpu
    dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
    torch.manual_seed(0)
    torch.cuda.set_device(gpu)

    train_dataset = get_dataset(transform=get_aug(train=True, **args.aug_kwargs),
                                train=True, **args.dataset_kwargs)
    # Each rank sees a disjoint shard of the training set
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        shuffle=False,  # shuffling is handled by the sampler
        batch_size=(args.train.batch_size // args.gpus),
        sampler=train_sampler,
        **args.dataloader_kwargs)

    memory_dataset = get_dataset(transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
                                 train=True, **args.dataset_kwargs)
    memory_loader = torch.utils.data.DataLoader(
        dataset=memory_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        **args.dataloader_kwargs)

    test_dataset = get_dataset(transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
                               train=False, **args.dataset_kwargs)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        **args.dataloader_kwargs)

    print("Per-GPU batch size:", args.train.batch_size // args.gpus)

    # Define model: convert BatchNorm to SyncBatchNorm before wrapping in DDP
    model = get_model(args.model).cuda(gpu)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    # Define optimizer (linear scaling rule: lr = base_lr * batch_size / 256)
    optimizer = get_optimizer(
        args.train.optimizer.name, model,
        lr=args.train.base_lr * args.train.batch_size / 256,
        momentum=args.train.optimizer.momentum,
        weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs, args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs, args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 (predictor)
    )

    if gpu == 0:
        logger = Logger(tensorboard=args.logger.tensorboard,
                        matplotlib=args.logger.matplotlib, log_dir=args.log_dir)
    accuracy = 0

    # Start training
    global_progress = tqdm(range(0, args.train.stop_at_epoch), desc='Training')
    for epoch in global_progress:
        model.train()
        train_sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}',
                              disable=args.hide_progress)
        for idx, ((images1, images2), labels) in enumerate(local_progress):
            model.zero_grad()
            data_dict = model.forward(images1.cuda(non_blocking=True),
                                      images2.cuda(non_blocking=True))
            loss = data_dict['loss']  # ddp: gradients are averaged across ranks
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})
            local_progress.set_postfix(data_dict)
            if gpu == 0:
                logger.update_scalers(data_dict)

        if args.train.knn_monitor and epoch % args.train.knn_interval == 0 and gpu == 0:
            accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, gpu,
                                   k=min(args.train.knn_k, len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        if gpu == 0:
            logger.update_scalers(epoch_dict)

        # Save checkpoint (rank 0 only)
        if gpu == 0 and epoch % args.train.knn_interval == 0:
            model_path = os.path.join(args.ckpt_dir, f"{args.name}_{epoch + 1}.pth")
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict()
            }, model_path)
            print(f"Model saved to {model_path}")
            with open(os.path.join(args.log_dir, "checkpoint_path.txt"), 'w+') as f:
                f.write(f'{model_path}')

    # if args.eval is not False and gpu == 0:
    #     args.eval_from = model_path
    #     linear_eval(args)

    dist.destroy_process_group()
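
# Launch sketch (an addition, not from the original source): the DDP entry
# point above expects one process per GPU, which torch.multiprocessing.spawn
# provides. `get_args()` is hypothetical, and `args.nodes` is assumed alongside
# the `args.nr` / `args.gpus` fields that main() already uses.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    args = get_args()  # hypothetical config loader
    args.world_size = args.nodes * args.gpus  # total process count across nodes
    os.environ['MASTER_ADDR'] = 'localhost'   # single-node default; for multi-node,
    os.environ['MASTER_PORT'] = '8888'        # point these at the master host/port
    # mp.spawn calls main(local_gpu_index, args) once per GPU on this node
    mp.spawn(main, nprocs=args.gpus, args=(args,))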