def __init__(self, device, train_loader, val_loader, model, MODEL_PATH):
    self.device = device
    self.model_path = MODEL_PATH
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.model = model.to(device)
    base_optimizer = optim.Adam(self.model.parameters(), lr=2e-4)
    self.optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
    self.criterion = nn.BCEWithLogitsLoss()
    self.scaler = GradScaler()
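# Note: this trainer pairs the LARS-wrapped optimizer with GradScaler, so the
# training step presumably follows the standard torch.cuda.amp pattern. A
# minimal sketch, assuming (inputs, targets) batches and the attributes set up
# in __init__ above; `train_step` itself is hypothetical, not from the source.
from torch.cuda.amp import autocast

def train_step(self, batch):
    inputs, targets = (t.to(self.device) for t in batch)
    self.optimizer.zero_grad()
    with autocast():  # run the forward pass in mixed precision
        logits = self.model(inputs)
        loss = self.criterion(logits, targets)
    # scale the loss so small fp16 gradients do not underflow
    self.scaler.scale(loss).backward()
    self.scaler.step(self.optimizer)  # skips the step if gradients overflowed
    self.scaler.update()
    return loss.item()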
import copy
import pickle

import torch
from torch.optim import SGD
from torchlars import LARS


def test_pickle():
    sgd = SGD([torch.Tensor()], lr=0.1)
    lars = LARS(sgd, trust_coef=0.42)
    lars2 = pickle.loads(pickle.dumps(lars))
    assert isinstance(lars2.optim, SGD)
    assert lars2.optim.param_groups[0]['lr'] == 0.1
    assert lars2.trust_coef == 0.42
def test_deepcopy():
    sgd = SGD([torch.Tensor()], lr=0.1)
    lars = LARS(sgd, trust_coef=0.42)
    lars2 = copy.deepcopy(lars)
    assert isinstance(lars2.optim, SGD)
    assert lars2.optim.param_groups[0]['lr'] == 0.1
    assert lars2.trust_coef == 0.42
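# In the same spirit, a state_dict round trip could be checked. This is a
# hedged sketch only: it assumes torchlars' LARS forwards state_dict() /
# load_state_dict() to the wrapped optimizer, which is not verified here.
def test_state_dict_roundtrip():
    sgd = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    lars = LARS(sgd, trust_coef=0.42)
    state = lars.state_dict()  # assumed to delegate to the inner optimizer
    sgd2 = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.5)
    lars2 = LARS(sgd2, trust_coef=0.42)
    lars2.load_state_dict(state)
    assert lars2.optim.param_groups[0]['lr'] == 0.1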
model = DocumentReader(args.hidden_dim, args.emb_dim, args.layers,
                       NUM_DIRECTIONS, args.dropout, device).to(device)
model = torch.nn.DataParallel(model)
writer = SummaryWriter(comment="_nlp_%s_%s_%s" %
                       (args.optimizer, args.batch_size, args.learning_rate))
weight_decay = args.learning_rate / args.epochs

if args.optimizer == 'lamb':
    optimizer = Lamb(model.parameters(), lr=args.learning_rate,
                     weight_decay=weight_decay, betas=(.9, .999),
                     adam=False, writer=writer)
elif args.optimizer == 'lars':
    base_optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                     momentum=0.9, weight_decay=weight_decay)
    optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001,
                     writer=writer)
elif args.optimizer == 'sgd':
    optimizer = SGD(model.parameters(), lr=args.learning_rate, momentum=0.9,
                    weight_decay=weight_decay, writer=writer)
else:  # use adam optimizer
    optimizer = Lamb(model.parameters(), lr=args.learning_rate,
                     weight_decay=weight_decay, betas=(.9, .999),
                     adam=True, writer=writer)

print(f'The model has {count_parameters(model):,} trainable parameters')
ckpt_dir_name = "%s_%s_%s" % (args.working_dir, args.optimizer, args.batch_size)
model, optimizer = load_pretrained_model(
    model, optimizer, "%s/ckpt/%s" % (ckpt_dir_name, "best_weights.pt"))
print(args)
ckpt_dir = os.path.join(ckpt_dir_name, 'ckpt')
os.makedirs(ckpt_dir, exist_ok=True)
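# `load_pretrained_model` is called above but not defined in this excerpt. A
# plausible sketch, assuming the checkpoint layout used elsewhere in these
# scripts ('model' and 'optimizer' keys) and treating a missing file as a
# cold start; the real helper may differ.
def load_pretrained_model(model, optimizer, ckpt_path):
    if os.path.isfile(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("loaded pretrained weights from %s" % ckpt_path)
    else:
        print("no checkpoint at %s, training from scratch" % ckpt_path)
    return model, optimizer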
def main():
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    data_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    crop_padding = 32
    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    # elif args.aug == 'NULL' and args.dataset == 'cifar':
    #     train_transform = transforms.Compose([
    #         transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    #         transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    #         transforms.RandomGrayscale(p=0.2),
    #         transforms.RandomHorizontalFlip(p=0.5),
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    #     ])
    #
    #     test_transform = transforms.Compose([
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    #     ])
    elif args.aug == 'simple' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            get_color_distortion(1.0),
            transforms.ToTensor(),
            normalize,
        ])
        # TODO: Currently follow CMC
        test_transform = transforms.Compose([
            transforms.Resize(image_size + crop_padding),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'simple' and args.dataset == 'cifar':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(p=0.5),
            get_color_distortion(0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    else:
        raise NotImplementedError('augmentation not supported: {}'.format(args.aug))

    # Get Datasets
    if args.dataset == "imagenet":
        train_dataset = ImageFolderInstance(data_folder, transform=train_transform,
                                            two_crop=args.moco)
        print(len(train_dataset))
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=(train_sampler is None), num_workers=args.num_workers,
            pin_memory=True, sampler=train_sampler)

        test_dataset = datasets.ImageFolder(val_folder, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=256, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)
    elif args.dataset == 'cifar':
        # cifar-10 dataset
        if args.contrastive_model == 'simclr':
            train_dataset = CIFAR10Instance_double(root='./data', train=True,
                                                   download=True,
                                                   transform=train_transform,
                                                   double=True)
        else:
            train_dataset = CIFAR10Instance(root='./data', train=True,
                                            download=True,
                                            transform=train_transform)
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=(train_sampler is None), num_workers=args.num_workers,
            pin_memory=True, sampler=train_sampler, drop_last=True)

        test_dataset = CIFAR10Instance(root='./data', train=False,
                                       download=True, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=100, shuffle=False,
            num_workers=args.num_workers)

    # create model and optimizer
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        model = InsResNet50()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50()
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=4)
    elif args.model == 'resnet50_cifar':
        model = InsResNet50_cifar()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50_cifar()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    if args.contrastive_model == 'moco':
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.contrastive_model == 'moco':
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    elif args.contrastive_model == 'simclr':
        contrast = None
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    if args.softmax:
        criterion = NCESoftmaxLoss()
    elif args.contrastive_model == 'simclr':
        criterion = BatchCriterion(1, args.nce_t, args.batch_size)
    else:
        criterion = NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.contrastive_model == 'moco':
        model_ema = model_ema.cuda()

    # Exclude BN and bias if needed
    weight_decay = args.weight_decay
    if weight_decay and args.filter_weight_decay:
        parameters = add_weight_decay(model, weight_decay,
                                      args.filter_weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    optimizer = torch.optim.SGD(parameters, lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=weight_decay)
    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level)
        if args.contrastive_model == 'moco':
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0, momentum=0, weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema, optimizer_ema,
                                                      opt_level=args.opt_level)

    if args.LARS:
        optimizer = LARS(optimizer=optimizer, eps=1e-8, trust_coef=0.001)

    # optionally resume from a checkpoint
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if contrast:
                contrast.load_state_dict(checkpoint['contrast'])
            if args.contrastive_model == 'moco':
                model_ema.load_state_dict(checkpoint['model_ema'])
            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])
            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        print("==> training...")
        time1 = time.time()

        if args.contrastive_model == 'moco':
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        elif args.contrastive_model == 'simclr':
            print("Train using simclr")
            loss, prob = train_simclr(epoch, train_loader, model, criterion,
                                      optimizer, args)
        else:
            print("Train using InsDis")
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)

        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        test_epoch = 2
        if epoch % test_epoch == 0:
            model.eval()
            if args.contrastive_model == 'moco':
                model_ema.eval()
            print('----------Evaluation---------')
            start = time.time()
            if args.dataset == 'cifar':
                acc = kNN(epoch, model, train_loader, test_loader, 200,
                          args.nce_t, n_data, low_dim=128, memory_bank=None)
            print("Evaluation Time: '{}'s".format(time.time() - start))
            # writer.add_scalar('nn_acc', acc, epoch)
            logger.log_value('Test accuracy', acc, epoch)
            # print('accuracy: {}% \t (best acc: {}%)'.format(acc, best_acc))
            print('[Epoch]: {}'.format(epoch))
            print('accuracy: {}%'.format(acc))
            # test_log_file.flush()

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'model': model.state_dict(),
                # 'contrast': contrast.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            if args.contrastive_model == 'moco':
                state['model_ema'] = model_ema.state_dict()
            if args.amp:
                state['amp'] = amp.state_dict()
            save_file = os.path.join(
                args.model_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
            # help release GPU memory
            del state

        # saving the model
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            # 'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.contrastive_model == 'moco':
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()

        save_file = os.path.join(args.model_folder, 'current.pth')
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
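# `moment_update(model, model_ema, 0)` above initializes the EMA encoder, but
# the helper itself is imported elsewhere. A hedged sketch of the usual
# MoCo/CMC-style momentum update it appears to be: model_ema becomes
# m * model_ema + (1 - m) * model, so m=0 makes it an exact copy.
def moment_update(model, model_ema, m):
    for p, p_ema in zip(model.parameters(), model_ema.parameters()):
        p_ema.data.mul_(m).add_(p.data, alpha=1 - m)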
def run_main(self):
    """Fix the random seed"""
    np.random.seed(self.seed)
    torch.manual_seed(self.seed)

    """RandAugment parameters"""
    if self.flag_randaug == 1:
        if self.rand_n == 0 and self.rand_m == 0:
            if self.n_model == 'ResNet':
                self.rand_n = 2
                self.rand_m = 9
            elif self.n_model == 'WideResNet':
                if self.n_data == 'CIFAR-10':
                    self.rand_n = 3
                    self.rand_m = 5
                elif self.n_data == 'CIFAR-100':
                    self.rand_n = 2
                    self.rand_m = 14
                elif self.n_data == 'SVHN':
                    self.rand_n = 3
                    self.rand_m = 7

    """Dataset"""
    traintest_dataset = dataset.MyDataset_training(
        n_data=self.n_data, num_data=self.num_training_data, seed=self.seed,
        flag_randaug=self.flag_randaug, rand_n=self.rand_n,
        rand_m=self.rand_m, cutout=self.cutout)
    (self.num_channel, self.num_classes, self.size_after_cnn,
     self.input_size, self.hidden_size) = traintest_dataset.get_info(n_data=self.n_data)

    n_samples = len(traintest_dataset)
    if self.num_training_data == 0:
        self.num_training_data = n_samples

    if self.flag_traintest == 1:
        # train_size = self.num_classes * 100
        train_size = int(n_samples * 0.65)
        test_size = n_samples - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            traintest_dataset, [train_size, test_size])
        train_sampler = None
        test_sampler = None
    else:
        train_dataset = traintest_dataset
        test_dataset = dataset.MyDataset_test(n_data=self.n_data)
        train_sampler = train_dataset.sampler
        test_sampler = test_dataset.sampler

    num_workers = 16
    train_shuffle = True
    test_shuffle = False
    if train_sampler:
        train_shuffle = False
    if test_sampler:
        test_shuffle = False

    self.train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=self.batch_size_training,
        sampler=train_sampler, shuffle=train_shuffle,
        num_workers=num_workers, pin_memory=True)
    self.test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=self.batch_size_test,
        sampler=test_sampler, shuffle=test_shuffle,
        num_workers=num_workers, pin_memory=True)

    """Transfer learning"""
    if self.flag_transfer == 1:
        pretrained = True
        num_classes = 1000
    else:
        pretrained = False
        num_classes = self.num_classes

    """Neural network model"""
    model = None
    if self.n_model == 'CNN':
        model = cnn.ConvNet(num_classes=self.num_classes,
                            num_channel=self.num_channel,
                            size_after_cnn=self.size_after_cnn,
                            n_aug=self.n_aug)
    elif self.n_model == 'ResNet':
        model = resnet.ResNet(n_data=self.n_data, depth=50,
                              num_classes=self.num_classes,
                              num_channel=self.num_channel,
                              n_aug=self.n_aug, bottleneck=True)  # resnet50
        # model = resnet.ResNet(n_data=self.n_data, depth=200, num_classes=self.num_classes, num_channel=self.num_channel, bottleneck=True)  # resnet200
    elif self.n_model == 'WideResNet':
        # model = resnet.wide_resnet50_2(num_classes=self.num_classes, num_channel=self.num_channel)
        # model = WideResNet(depth=40, widen_factor=2, dropout_rate=0.0, num_classes=self.num_classes, num_channel=self.num_channel)  # wresnet40_2
        model = wideresnet.WideResNet(depth=28, widen_factor=10, dropout_rate=0.0,
                                      num_classes=self.num_classes,
                                      num_channel=self.num_channel,
                                      n_aug=self.n_aug)  # wresnet28_10
    print(torch.__version__)

    """Transfer learning"""
    if self.flag_transfer == 1:
        for param in model.parameters():
            param.requires_grad = False
        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, self.num_classes)

    """Show parameters"""
    if self.show_params == 1:
        params = 0
        for p in model.parameters():
            if p.requires_grad:
                params += p.numel()
        print(params)

    """GPU setting"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    if device == 'cuda':
        if self.gpu_multi == 1:
            model = torch.nn.DataParallel(model)
            torch.backends.cudnn.benchmark = True
            print('GPU={}'.format(torch.cuda.device_count()))

    """Loss function"""
    if self.lb_smooth > 0.0:
        criterion = objective.SmoothCrossEntropyLoss(self.lb_smooth)
    else:
        criterion = objective.SoftCrossEntropy()

    """Optimizer"""
    optimizer = None
    if self.opt == 0:  # Adam
        if self.flag_transfer == 1:
            optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    elif self.opt == 1:  # SGD
        if self.n_model == 'ResNet':
            lr = 0.1
            weight_decay = 0.0001
        elif self.n_model == 'WideResNet':
            if self.n_data == 'SVHN':
                lr = 0.005
                weight_decay = 0.001
            else:
                lr = 0.1
                weight_decay = 0.0005
        else:
            lr = 0.1
            weight_decay = 0.0005
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                                    weight_decay=weight_decay, nesterov=True)

    if self.flag_lars == 1:
        from torchlars import LARS
        optimizer = LARS(optimizer)

    """Learning rate scheduling"""
    scheduler = None
    if self.flag_lr_schedule == 2:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.num_epochs, eta_min=0.)
    elif self.flag_lr_schedule == 3:
        if self.num_epochs == 90:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80])
        elif self.num_epochs == 180:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 120, 160])
        elif self.num_epochs == 270:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [90, 180, 240])

    if self.flag_warmup == 1:
        if self.n_model == 'ResNet':
            multiplier = 2
            total_epoch = 3
        elif self.n_model == 'WideResNet':
            multiplier = 2
            if self.n_data == 'SVHN':
                total_epoch = 3
            else:
                total_epoch = 5
        else:
            multiplier = 2
            total_epoch = 3
        scheduler = GradualWarmupScheduler(optimizer, multiplier=multiplier,
                                           total_epoch=total_epoch,
                                           after_scheduler=scheduler)

    """Initialize"""
    self.flag_noise = np.random.randint(0, 2, self.num_training_data)
    if self.flag_acc5 == 1:
        results = np.zeros((self.num_epochs, 6))
    else:
        results = np.zeros((self.num_epochs, 5))
    start_time = timeit.default_timer()

    t = 0
    # fixed_interval = 10
    fixed_interval = 1
    loss_fixed_all = np.zeros(self.num_epochs // fixed_interval)
    self.loss_training_batch = np.zeros(
        int(self.num_epochs * np.ceil(self.num_training_data / self.batch_size_training)))

    for epoch in range(self.num_epochs):
        """Training"""
        model.train()
        start_epoch_time = timeit.default_timer()  # Get the start time of this epoch
        loss_each_all = np.zeros(self.num_training_data)
        loss_training_all = 0
        loss_test_all = 0

        """Learning rate scheduling"""
        # if self.flag_lr_schedule == 1:
        #     if self.num_epochs == 200:
        #         if epoch == 100:
        #             optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

        if self.flag_variance == 1:  # when computing the variance of training loss
            if epoch % fixed_interval == 0:
                loss_fixed = np.zeros(self.num_training_data // self.batch_size_training)
                for i, (images, labels, index) in enumerate(self.train_loader):
                    if np.array(images.data.cpu()).ndim == 3:
                        images = images.reshape(images.shape[0], 1, images.shape[1],
                                                images.shape[2]).to(device)
                    else:
                        images = images.to(device)
                    labels = labels.to(device)

                    flag_onehot = 0
                    if self.flag_spa == 1:
                        outputs_fixed = model.forward(images)
                        if flag_onehot == 0:
                            labels = np.identity(self.num_classes)[labels]
                    else:
                        outputs_fixed = model(images)
                        labels = np.identity(self.num_classes)[np.array(labels.data.cpu())]
                    labels = util.to_var(torch.from_numpy(labels).float())

                    loss_fixed[i] = criterion.forward(outputs_fixed, labels)
                loss_fixed_all[t] = np.var(loss_fixed)
                t += 1

        total_steps = len(self.train_loader)
        steps = 0
        num_training_data = 0
        for i, (images, labels, index) in enumerate(self.train_loader):
            steps += 1
            if np.array(images.data.cpu()).ndim == 3:
                images = images.reshape(images.shape[0], 1, images.shape[1],
                                        images.shape[2]).to(device)
            else:
                images = images.to(device)
            labels = labels.to(device)

            """Save images"""
            if self.save_images == 1:
                util.save_images(images)

            """Get training loss for each sample by inputting data before applying data augmentation"""
            if self.flag_spa == 1:
                outputs = model(images)
                labels_spa = labels.clone()
                if labels_spa.ndim == 1:
                    labels_spa = torch.eye(self.num_classes,
                                           device=device)[labels_spa].clone()  # To one-hot
                loss_each = criterion.forward_each_example(outputs, labels_spa)  # Loss for each sample
                loss_each_all[index] = np.array(loss_each.data.cpu())  # Put losses for target samples
                self.flag_noise = util.flag_update(loss_each_all, self.judge_noise)  # Update the noise flag

            """Forward propagation"""
            if self.flag_spa == 1:
                images, labels = util.self_paced_augmentation(
                    images=images, labels=labels, flag_noise=self.flag_noise,
                    index=np.array(index.data.cpu()), n_aug=self.n_aug,
                    num_classes=self.num_classes)
            else:
                images, labels = util.run_n_aug(
                    x=images, y=labels, n_aug=self.n_aug,
                    num_classes=self.num_classes)
            outputs = model(images)

            if labels.ndim == 1:
                labels = torch.eye(self.num_classes, device=device)[labels].clone()
            loss_training = criterion.forward(outputs, labels)
            loss_training_all += loss_training.item() * outputs.shape[0]
            # self.loss_training_batch[int(i + epoch * np.ceil(self.num_training_data / self.batch_size_training))] = loss_training * outputs.shape[0]
            num_training_data += images.shape[0]

            """Back propagation and update"""
            optimizer.zero_grad()
            loss_training.backward()
            optimizer.step()

            """When changing flag_noise randomly"""
            # if self.flag_spa == 1:
            #     outputs = model.forward(x=images)
            #     if labels.ndim == 1:
            #         y_soft = torch.eye(self.num_classes, device=device)[labels]  # Convert to one-hot
            #     loss_each = criterion.forward_each_example(outputs, y_soft)  # Loss for each sample
            #     loss_each_all[index] = np.array(loss_each.data.cpu())  # Put losses for target samples
            #     self.flag_noise = util.flag_update(loss_each_all, self.judge_noise)  # Update the noise flag

        loss_training_each = loss_training_all / num_training_data
        # np.random.shuffle(self.flag_noise)

        """Test"""
        model.eval()
        with torch.no_grad():
            if self.flag_acc5 == 1:
                top1 = list()
                top5 = list()
            else:
                correct = 0
                total = 0
            num_test_data = 0

            for images, labels in self.test_loader:
                if np.array(images.data).ndim == 3:
                    images = images.reshape(images.shape[0], 1, images.shape[1],
                                            images.shape[2]).to(device)
                else:
                    images = images.to(device)
                labels = labels.to(device)

                outputs = model(x=images)

                if self.flag_acc5 == 1:
                    acc1, acc5 = util.accuracy(outputs.data, labels.long(), topk=(1, 5))
                    top1.append(acc1[0].item())
                    top5.append(acc5[0].item())
                else:
                    _, predicted = torch.max(outputs.data, 1)
                    correct += (predicted == labels.long()).sum().item()
                    total += labels.size(0)

                if labels.ndim == 1:
                    labels = torch.eye(self.num_classes, device=device)[labels]
                loss_test = criterion.forward(outputs, labels)
                loss_test_all += loss_test.item() * outputs.shape[0]
                num_test_data += images.shape[0]

        """Compute test results"""
        top1_avg = 0
        top5_avg = 0
        test_accuracy = 0
        if self.flag_acc5 == 1:
            top1_avg = sum(top1) / float(len(top1))
            top5_avg = sum(top5) / float(len(top5))
        else:
            test_accuracy = 100.0 * correct / total
        loss_test_each = loss_test_all / num_test_data

        """Compute running time"""
        end_epoch_time = timeit.default_timer()
        epoch_time = end_epoch_time - start_epoch_time
        num_flag = np.sum(self.flag_noise == 1)

        """Learning rate scheduling"""
        if self.flag_lr_schedule > 1 and scheduler is not None:
            scheduler.step(epoch - 1 + float(steps) / total_steps)

        """Show results"""
        flag_log = 1
        if flag_log == 1:
            if self.flag_acc5 == 1:
                print('Epoch [{}/{}], Train Loss: {:.4f}, Top1 Test Acc: {:.3f} %, '
                      'Top5 Test Acc: {:.3f} %, Test Loss: {:.4f}, '
                      'Epoch Time: {:.2f}s, Num_flag: {}'.format(
                          epoch + 1, self.num_epochs, loss_training_each,
                          top1_avg, top5_avg, loss_test_each, epoch_time, num_flag))
            else:
                print('Epoch [{}/{}], Train Loss: {:.4f}, Test Acc: {:.3f} %, '
                      'Test Loss: {:.4f}, Epoch Time: {:.2f}s, Num_flag: {}'.format(
                          epoch + 1, self.num_epochs, loss_training_each,
                          test_accuracy, loss_test_each, epoch_time, num_flag))

        if self.flag_wandb == 1:
            wandb.log({"loss_training_each": loss_training_each})
            wandb.log({"test_accuracy": test_accuracy})
            wandb.log({"loss_test_each": loss_test_each})
            wandb.log({"num_flag": num_flag})
            wandb.log({"epoch_time": epoch_time})

        if self.save_file == 1:
            if flag_log == 1:
                if self.flag_acc5 == 1:
                    results[epoch][0] = loss_training_each
                    results[epoch][1] = top1_avg
                    results[epoch][2] = top5_avg
                    results[epoch][3] = loss_test_each
                    results[epoch][4] = num_flag
                    results[epoch][5] = epoch_time
                else:
                    results[epoch][0] = loss_training_each
                    results[epoch][1] = test_accuracy
                    results[epoch][2] = loss_test_each
                    results[epoch][3] = num_flag
                    results[epoch][4] = epoch_time

    end_time = timeit.default_timer()
    flag_log = 1
    if flag_log == 1:
        print(' ran for %.4fm' % ((end_time - start_time) / 60.))

    """Show accuracy"""
    top1_avg_max = np.max(results[:, 1])
    print(top1_avg_max)
    if flag_log == 1 and self.flag_acc5 == 1:
        top5_avg_max = np.max(results[:, 2])
        print(top5_avg_max)

    """Save files"""
    if self.save_file == 1:
        if self.flag_randaug == 1:
            np.savetxt('results/data_%s_model_%s_num_%s_randaug_%s_n_%s_m_%s_seed_%s_acc_%s.csv'
                       % (self.n_data, self.n_model, self.num_training_data,
                          self.flag_randaug, self.rand_n, self.rand_m,
                          self.seed, top1_avg_max),
                       results, delimiter=',')
        else:
            if self.flag_spa == 1:
                np.savetxt('results/data_%s_model_%s_judge_%s_aug_%s_num_%s_seed_%s_acc_%s.csv'
                           % (self.n_data, self.n_model, self.judge_noise, self.n_aug,
                              self.num_training_data, self.seed, top1_avg_max),
                           results, delimiter=',')
            else:
                np.savetxt('results/data_%s_model_%s_aug_%s_num_%s_seed_%s_acc_%s.csv'
                           % (self.n_data, self.n_model, self.n_aug,
                              self.num_training_data, self.seed, top1_avg_max),
                           results, delimiter=',')

    if self.flag_variance == 1:
        np.savetxt('results/loss_variance_judge_%s_aug_%s_acc_%s.csv'
                   % (self.judge_noise, self.n_aug, top1_avg_max),
                   loss_fixed_all, delimiter=',')
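# `util.accuracy` used in the test loop above is not defined in this excerpt.
# A plausible sketch of the standard top-k accuracy helper it appears to
# follow; it returns one percentage tensor per requested k, matching the
# acc1[0].item() / acc5[0].item() usage above.
def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)  # indices of the top-k classes
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res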
def train(self, dataloader, temperature, ckpt_path, n_epochs=90, save_size=10):
    # trainers
    criterion = NTCrossEntropyLoss(temperature, self.batch_size,
                                   self.device).to(self.device)
    optimizer = LARS(torch.optim.SGD(self.model.parameters(), lr=4))
    # optimizer = optimizer.to(self.device)

    losses = []
    for epoch in range(n_epochs):
        with tqdm(total=len(dataloader)) as progress:
            running_loss = 0
            i = 0
            for (xis, xjs), _ in dataloader:
                i += 1
                optimizer.zero_grad()
                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                # Get representations and projections
                his, zis = self.model(xis)
                hjs, zjs = self.model(xjs)

                # normalize
                zis = F.normalize(zis, dim=1)
                zjs = F.normalize(zjs, dim=1)

                loss = criterion(zis, zjs)
                running_loss += loss.item()

                # optimize
                loss.backward()
                optimizer.step()

                # update tqdm
                progress.set_description('train loss:{:.4f}'.format(loss.item()))
                progress.update()

                # record loss
                if i % save_size == (save_size - 1):
                    losses.append(running_loss / save_size)
                    running_loss = 0

        # save model
        if epoch % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': losses,
            }, ckpt_path + f'{epoch}')

    return self.return_model(), losses
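# `NTCrossEntropyLoss` is constructed above but not shown. A hedged sketch of
# the SimCLR NT-Xent loss it presumably implements, written for the already
# normalized zis/zjs in the loop above. Only the constructor signature comes
# from the call site; the internals here are an assumption.
import torch
import torch.nn as nn
import torch.nn.functional as F

class NTCrossEntropyLoss(nn.Module):
    def __init__(self, temperature, batch_size, device):
        super().__init__()
        self.temperature = temperature
        self.batch_size = batch_size
        self.device = device

    def forward(self, zis, zjs):
        n = zis.size(0)  # actual batch size of this step
        z = torch.cat([zis, zjs], dim=0)    # (2n, d)
        sim = z @ z.t() / self.temperature  # cosine similarities (inputs normalized)
        sim.fill_diagonal_(float('-inf'))   # mask self-similarity
        # positives sit n rows apart: i <-> i + n
        targets = torch.cat([torch.arange(n, 2 * n),
                             torch.arange(0, n)]).to(self.device)
        return F.cross_entropy(sim, targets)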
    return v_loss


model = resnet50_cifar(args.feature_size).type(dtype)
if args.multi_gpu:
    model = torch.nn.DataParallel(model, device_ids=[4, 5, 6, 7])
    print('Multi gpu')

# init?
base_optimizer = optim.Adam(model.parameters(), lr=args.lr,
                            weight_decay=args.decay_lr)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
scheduler = ExponentialLR(optimizer, gamma=args.decay_lr)

# Main training loop
best_loss = np.inf

# Resume training
if args.load_model is not None:
    if os.path.isfile(args.load_model):
        checkpoint = torch.load(args.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        base_optimizer.load_state_dict(checkpoint['base_optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_loss = checkpoint['val_loss']
        epoch = checkpoint['epoch']
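# The resume block above implies a matching writer. A hedged sketch of the
# checkpoint saver it expects: it emits exactly the keys the resume block
# reads; the function name and call site are assumptions, not from the source.
def save_checkpoint(path, model, optimizer, base_optimizer, scheduler,
                    val_loss, epoch):
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'base_optimizer': base_optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'val_loss': val_loss,
        'epoch': epoch,
    }, path)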
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.distributed:
        if args.dist_url == 'env://' and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    model = PixPro(
        encoder=resnet50,
        dim1=args.pcl_dim_1,
        dim2=args.pcl_dim_2,
        momentum=args.encoder_momentum,
        threshold=args.threshold,
        temperature=args.T,
        sharpness=args.sharpness,
        num_linear=args.num_linear,
    )

    args.lr = args.lr_base * args.batch_size / 256

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            # convert batch norm --> sync batch norm
            sync_bn_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                sync_bn_model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        raise NotImplementedError('only DDP is supported.')

    base_optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                     weight_decay=args.weight_decay)
    optimizer = LARS(optimizer=base_optimizer, eps=1e-8)

    writer = SummaryWriter(args.log_dir)

    if args.resume:
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    dataset = PixProDataset(root=args.train_path, args=args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        train_sampler = None

    loader = DataLoader(dataset, batch_size=args.batch_size,
                        shuffle=(train_sampler is None),
                        num_workers=args.workers, pin_memory=True,
                        sampler=train_sampler, drop_last=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_lr(optimizer, epoch, args)

        train(args, epoch, loader, model, optimizer, writer)

        if not args.multiprocessing_distributed or \
                (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
            save_name = '{}.pth.tar'.format(epoch)
            save_name = os.path.join(args.checkpoint_dir, save_name)
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, save_name)
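# The learning rate above follows the linear scaling rule commonly paired
# with LARS (lr = lr_base * batch_size / 256). `adjust_lr` is not shown; a
# hedged sketch of a typical cosine-decay version follows. The schedule shape
# is an assumption, not taken from this repo.
import math

def adjust_lr(optimizer, epoch, args):
    # cosine-decay the linearly scaled base lr over args.epochs
    lr = args.lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr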
def prepare_train_eval(rank, world_size, run_name, train_config, model_config,
                       hdf5_path_train):
    cfgs = dict2clsattr(train_config, model_config)
    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path, mu, sigma, inception_model = \
        None, 0, 0, None, None, None, None, None

    if cfgs.distributed_data_parallel:
        print("Use GPU: {} for training.".format(rank))
        setup(rank, world_size)
        torch.cuda.set_device(rank)

    writer = SummaryWriter(log_dir=join('./logs', run_name)) if rank == 0 else None
    if rank == 0:
        logger = make_logger(run_name, None)
        logger.info('Run name : {run_name}'.format(run_name=run_name))
        logger.info(train_config)
        logger.info(model_config)
    else:
        logger = None

    ##### load dataset #####
    if rank == 0:
        logger.info('Load train datasets...')
    train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True,
                                download=True, resize_size=cfgs.img_size,
                                hdf5_path=hdf5_path_train,
                                random_flip=cfgs.random_flip_preprocessing)
    if cfgs.reduce_train_dataset < 1.0:
        num_train = int(cfgs.reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset, [num_train, len(train_dataset) - num_train])
    if rank == 0:
        logger.info('Train dataset size : {dataset_size}'.format(
            dataset_size=len(train_dataset)))

    if rank == 0:
        logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))
    eval_mode = True if cfgs.eval_type == 'train' else False
    eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path,
                               train=eval_mode, download=True,
                               resize_size=cfgs.img_size, hdf5_path=None,
                               random_flip=False)
    if rank == 0:
        logger.info('Eval dataset size : {dataset_size}'.format(
            dataset_size=len(eval_dataset)))

    if cfgs.distributed_data_parallel:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        cfgs.batch_size = cfgs.batch_size // world_size
    else:
        train_sampler = None

    train_dataloader = DataLoader(train_dataset, batch_size=cfgs.batch_size,
                                  shuffle=(train_sampler is None),
                                  pin_memory=True, num_workers=cfgs.num_workers,
                                  sampler=train_sampler, drop_last=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size,
                                 shuffle=False, pin_memory=True,
                                 num_workers=cfgs.num_workers, drop_last=False)

    ##### build model #####
    if rank == 0:
        logger.info('Build model...')
    module = __import__('models.{architecture}'.format(architecture=cfgs.architecture),
                        fromlist=['something'])
    if rank == 0:
        logger.info('Modules are located on models.{architecture}.'.format(
            architecture=cfgs.architecture))
    Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size,
                           cfgs.g_conv_dim, cfgs.g_spectral_norm,
                           cfgs.attention, cfgs.attention_after_nth_gen_block,
                           cfgs.activation_fn, cfgs.conditional_strategy,
                           cfgs.num_classes, cfgs.g_init, cfgs.G_depth,
                           cfgs.mixed_precision).to(rank)
    Dis = module.Discriminator(cfgs.img_size, cfgs.d_conv_dim,
                               cfgs.d_spectral_norm, cfgs.attention,
                               cfgs.attention_after_nth_dis_block,
                               cfgs.activation_fn, cfgs.conditional_strategy,
                               cfgs.hypersphere_dim, cfgs.num_classes,
                               cfgs.nonlinear_embed, cfgs.normalize_embed,
                               cfgs.d_init, cfgs.D_depth,
                               cfgs.mixed_precision).to(rank)

    if cfgs.ema:
        if rank == 0:
            logger.info('Prepare EMA for G with decay of {}.'.format(cfgs.ema_decay))
        Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size,
                                    cfgs.g_conv_dim, cfgs.g_spectral_norm,
                                    cfgs.attention,
                                    cfgs.attention_after_nth_gen_block,
                                    cfgs.activation_fn,
                                    cfgs.conditional_strategy,
                                    cfgs.num_classes, initialize=False,
                                    G_depth=cfgs.G_depth,
                                    mixed_precision=cfgs.mixed_precision).to(rank)
        Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    if rank == 0:
        logger.info(count_parameters(Gen))
    if rank == 0:
        logger.info(Gen)
    if rank == 0:
        logger.info(count_parameters(Dis))
    if rank == 0:
        logger.info(Dis)

    ### define loss functions and optimizers
    G_loss = {'vanilla': loss_dcgan_gen, 'least_square': loss_lsgan_gen,
              'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen}
    D_loss = {'vanilla': loss_dcgan_dis, 'least_square': loss_lsgan_dis,
              'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis}

    if cfgs.optimizer == "SGD":
        G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Gen.parameters()),
                                      cfgs.g_lr, momentum=cfgs.momentum,
                                      nesterov=cfgs.nesterov)
        D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Dis.parameters()),
                                      cfgs.d_lr, momentum=cfgs.momentum,
                                      nesterov=cfgs.nesterov)
    elif cfgs.optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Gen.parameters()),
                                          cfgs.g_lr, momentum=cfgs.momentum,
                                          alpha=cfgs.alpha)
        D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Dis.parameters()),
                                          cfgs.d_lr, momentum=cfgs.momentum,
                                          alpha=cfgs.alpha)
    elif cfgs.optimizer == "Adam":
        G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()),
                                       cfgs.g_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
        D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()),
                                       cfgs.d_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
    else:
        raise NotImplementedError

    if cfgs.LARS_optimizer:
        G_optimizer = LARS(optimizer=G_optimizer, eps=1e-8, trust_coef=0.001)
        D_optimizer = LARS(optimizer=D_optimizer, eps=1e-8, trust_coef=0.001)

    ##### load checkpoints if needed #####
    if cfgs.checkpoint_folder is None:
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
    else:
        when = "current" if cfgs.load_current is True else "best"
        if not exists(abspath(cfgs.checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
        g_checkpoint_dir = glob.glob(join(
            checkpoint_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(join(
            checkpoint_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path = \
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        if rank == 0:
            logger = make_logger(run_name, None)
        if cfgs.ema:
            g_ema_checkpoint_dir = glob.glob(join(
                checkpoint_dir,
                "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name)) if rank == 0 else None
        if cfgs.train_configs['train']:
            assert cfgs.seed == trained_seed, "Seed for sampling random numbers should be same!"
        if rank == 0:
            logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        if rank == 0:
            logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))

        if cfgs.freeze_layers > -1:
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = \
                None, 0, 0, None, None

    ##### wrap models with DP and convert BN to Sync BN #####
    if world_size > 1:
        if cfgs.distributed_data_parallel:
            if cfgs.synchronized_bn:
                process_group = torch.distributed.new_group(
                    [w for w in range(world_size)])
                Gen = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Gen, process_group)
                Dis = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Dis, process_group)
                if cfgs.ema:
                    Gen_copy = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        Gen_copy, process_group)

            Gen = DDP(Gen, device_ids=[rank], broadcast_buffers=False,
                      find_unused_parameters=True)
            Dis = DDP(Dis, device_ids=[rank], broadcast_buffers=False,
                      find_unused_parameters=True)
            if cfgs.ema:
                Gen_copy = DDP(Gen_copy, device_ids=[rank],
                               broadcast_buffers=False,
                               find_unused_parameters=True)
        else:
            Gen = DataParallel(Gen, output_device=rank)
            Dis = DataParallel(Dis, output_device=rank)
            if cfgs.ema:
                Gen_copy = DataParallel(Gen_copy, output_device=rank)

            if cfgs.synchronized_bn:
                Gen = convert_model(Gen).to(rank)
                Dis = convert_model(Dis).to(rank)
                if cfgs.ema:
                    Gen_copy = convert_model(Gen_copy).to(rank)

    ##### load the inception network and prepare first/second moments for calculating FID #####
    if cfgs.eval:
        inception_model = InceptionV3().to(rank)
        if world_size > 1 and cfgs.distributed_data_parallel:
            toggle_grad(inception_model, on=True)
            inception_model = DDP(inception_model, device_ids=[rank],
                                  broadcast_buffers=False,
                                  find_unused_parameters=True)
        elif world_size > 1 and cfgs.distributed_data_parallel is False:
            inception_model = DataParallel(inception_model, output_device=rank)
        else:
            pass

        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=cfgs.eval_type,
                                              inception_model=inception_model,
                                              splits=1, run_name=run_name,
                                              logger=logger, device=rank)

    worker = make_worker(
        cfgs=cfgs,
        run_name=run_name,
        best_step=best_step,
        logger=logger,
        writer=writer,
        n_gpus=world_size,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        G_loss=G_loss[cfgs.adv_loss],
        D_loss=D_loss[cfgs.adv_loss],
        prev_ada_p=prev_ada_p,
        rank=rank,
        checkpoint_dir=checkpoint_dir,
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
    )

    if cfgs.train_configs['train']:
        step = worker.train(current_step=step, total_step=cfgs.total_step)

    if cfgs.eval:
        is_save = worker.evaluation(step=step,
                                    standing_statistics=cfgs.standing_statistics,
                                    standing_step=cfgs.standing_step)

    if cfgs.save_images:
        worker.save_images(is_generate=True, png=True, npz=True,
                           standing_statistics=cfgs.standing_statistics,
                           standing_step=cfgs.standing_step)

    if cfgs.image_visualization:
        worker.run_image_visualization(nrow=cfgs.nrow, ncol=cfgs.ncol,
                                       standing_statistics=cfgs.standing_statistics,
                                       standing_step=cfgs.standing_step)

    if cfgs.k_nearest_neighbor:
        worker.run_nearest_neighbor(nrow=cfgs.nrow, ncol=cfgs.ncol,
                                    standing_statistics=cfgs.standing_statistics,
                                    standing_step=cfgs.standing_step)

    if cfgs.interpolation:
        assert cfgs.architecture in ["big_resnet", "biggan_deep"], \
            "StudioGAN does not support interpolation analysis except for biggan and biggan_deep."
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol,
                                        fix_z=True, fix_y=False,
                                        standing_statistics=cfgs.standing_statistics,
                                        standing_step=cfgs.standing_step)
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol,
                                        fix_z=False, fix_y=True,
                                        standing_statistics=cfgs.standing_statistics,
                                        standing_step=cfgs.standing_step)

    if cfgs.frequency_analysis:
        worker.run_frequency_analysis(num_images=len(train_dataset) // cfgs.num_classes,
                                      standing_statistics=cfgs.standing_statistics,
                                      standing_step=cfgs.standing_step)

    if cfgs.tsne_analysis:
        worker.run_tsne(dataloader=eval_dataloader,
                        standing_statistics=cfgs.standing_statistics,
                        standing_step=cfgs.standing_step)
def train(cfg, writer, logger):
    # Setup random seeds to a determined value for reproduction
    # seed = 1337
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    # np.random.default_rng(seed)

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    logger.info("Using dataset: {}".format(data_path))
    t_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle=cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        split=cfg.data.train_split,
        norm=cfg.data.norm,
        augments=data_aug)
    v_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle=cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        split=cfg.data.val_split)

    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader, batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True, persistent_workers=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader, batch_size=10,
                                num_workers=cfg.train.n_workers)

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model, 2).to(device)
    input_size = (cfg.model.input_nbr, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=True)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # automatic multi-GPU execution; this works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrap optimizer with {cfg.train.optimizer.wrap}')

    scheduler = get_scheduler(optimizer, cfg.train.lr)
    loss_fn = get_loss_function(cfg)
    logger.info(f"Using loss {str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume))

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg.train.resume, checkpoint["epoch"]))

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                    # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir)
        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Setup Metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1
            file_a = file_a.to(device)
            file_b = file_b.to(device)
            label = label.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            # print(f'dtype: {file_a.dtype}')
            outputs = model(file_a, file_b)
            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()
            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the acc of the minibatch
            pred = outputs.max(1)[1].cpu().numpy()
            runing_metrics_train.update(label.cpu().numpy(), pred, mask.cpu().numpy())

            train_time_meter.update(time.time() - train_start_time)

            if it % cfg.train.print_interval == 0:
                # acc of the samples between print_interval
                score, _ = runing_metrics_train.get_scores()
                train_cls_0_acc, train_cls_1_acc = score['Acc']
                fmt_str = "Iter [{:d}/{:d}] train Loss: {:.4f} Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}"
                print_str = fmt_str.format(
                    it, train_iter,
                    loss.item(),  # extracts the loss's value as a Python float
                    train_time_meter.avg / cfg.train.batch_size,
                    train_cls_0_acc, train_cls_1_acc)
                runing_metrics_train.reset()
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it)
                writer.add_scalars('metrics/train',
                                   {'cls_0': train_cls_0_acc, 'cls_1': train_cls_1_acc}, it)
                # writer.add_scalar('train_metrics/acc/cls_0', train_cls_0_acc, it)
                # writer.add_scalar('train_metrics/acc/cls_1', train_cls_1_acc, it)

            if it % cfg.train.val_interval == 0 or it == train_iter:
                val_start_time = time.time()
                model.eval()  # change behavior like drop out
                with torch.no_grad():  # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val, mask_val) in valloader:
                        file_a_val = file_a_val.to(device)
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max() returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        running_metrics_val.update(label_val.numpy(), pred, mask_val.numpy())

                        label_val = label_val.to(device)
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                score, _ = running_metrics_val.get_scores()
                val_cls_0_acc, val_cls_1_acc = score['Acc']

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(
                    f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} "
                    f"Time/Image: {(time.time() - val_start_time) / len(v_loader):.4f}\n"
                    f"0: {val_cls_0_acc:.4f}\n1:{val_cls_1_acc:.4f}")
                # lr_now = optimizer.param_groups[0]['lr']
                # logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)
                logger.info('0: {:.4f}\n1:{:.4f}'.format(val_cls_0_acc, val_cls_1_acc))
                writer.add_scalars('metrics/val',
                                   {'cls_0': val_cls_0_acc, 'cls_1': val_cls_1_acc}, it)
                # writer.add_scalar('val_metrics/acc/cls_0', val_cls_0_acc, it)
                # writer.add_scalar('val_metrics/acc/cls_1', val_cls_1_acc, it)
                val_loss_meter.reset()
                running_metrics_val.reset()

                # OA = score["Overall_Acc"]
                val_macro_OA = (val_cls_0_acc + val_cls_1_acc) / 2
                if val_macro_OA >= best_macro_OA_now and it > 200:
                    best_macro_OA_now = val_macro_OA
                    best_macro_OA_iter_now = it
                    state = {
                        "epoch": it,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_macro_OA_now": best_macro_OA_now,
                        'best_macro_OA_iter_now': best_macro_OA_iter_now,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
                    torch.save(state, save_path)
                    logger.info("best OA now = %.8f" % (best_macro_OA_now))
                    logger.info("best OA iter now = %d" % (best_macro_OA_iter_now))

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            train_start_time = time.time()

    logger.info("best OA now = %.8f" % (best_macro_OA_now))
    logger.info("best OA iter now = %d" % (best_macro_OA_iter_now))

    state = {
        "epoch": it,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_macro_OA_now": best_macro_OA_now,
        'best_macro_OA_iter_now': best_macro_OA_iter_now,
    }
    save_path = os.path.join(
        writer.file_writer.get_logdir(),
        "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
    torch.save(state, save_path)
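# `averageMeter` and `runningScore` come from this repo's metrics module and
# are not shown in the excerpt. A hedged sketch of the standard running
# average meter the loop above relies on (the confusion-matrix based
# `runningScore` class is omitted):
class averageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count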
def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None,
                   metric='last', save_path=None, only_eval=False):
    if not reporter:
        reporter = lambda **kwargs: 0

    max_epoch = C.get()['epoch']
    trainsampler, trainloader, validloader, testloader_ = get_dataloaders(
        C.get()['dataset'], C.get()['batch'], dataroot, test_ratio, split_idx=cv_fold)

    # create a model & an optimizer
    model = get_model(C.get()['model'], num_class(C.get()['dataset']))

    lb_smooth = C.get()['optimizer'].get('label_smoothing', 0.0)
    if lb_smooth > 0.0:
        criterion = SmoothCrossEntropyLoss(lb_smooth)
    else:
        criterion = nn.CrossEntropyLoss()
    if C.get()['optimizer']['type'] == 'sgd':
        optimizer = optim.SGD(
            model.parameters(),
            lr=C.get()['lr'],
            momentum=C.get()['optimizer'].get('momentum', 0.9),
            weight_decay=C.get()['optimizer']['decay'],
            nesterov=C.get()['optimizer']['nesterov'])
    else:
        raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type'])

    if C.get()['optimizer'].get('lars', False):
        from torchlars import LARS
        optimizer = LARS(optimizer)
        logger.info('*** LARS Enabled.')

    lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
    if lr_scheduler_type == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=C.get()['epoch'], eta_min=0.)
    elif lr_scheduler_type == 'resnet':
        scheduler = adjust_learning_rate_resnet(optimizer)
    else:
        raise ValueError('invalid lr_scheduler=%s' % lr_scheduler_type)

    if C.get()['lr_schedule'].get('warmup', None):
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=C.get()['lr_schedule']['warmup']['multiplier'],
            total_epoch=C.get()['lr_schedule']['warmup']['epoch'],
            after_scheduler=scheduler)

    if not tag:
        from RandAugment.metrics import SummaryWriterDummy as SummaryWriter
        logger.warning('tag not provided, no tensorboard log.')
    else:
        from tensorboardX import SummaryWriter
    writers = [SummaryWriter(log_dir='./logs/%s/%s' % (tag, x))
               for x in ['train', 'valid', 'test']]

    result = OrderedDict()
    epoch_start = 1
    if save_path and os.path.exists(save_path):
        logger.info('%s file found. loading...' % save_path)
        data = torch.load(save_path)
        if 'model' in data or 'state_dict' in data:
            key = 'model' if 'model' in data else 'state_dict'
            logger.info('checkpoint epoch@%d' % data['epoch'])
            if not isinstance(model, DataParallel):
                model.load_state_dict({k.replace('module.', ''): v
                                       for k, v in data[key].items()})
            else:
                model.load_state_dict({k if 'module.' in k else 'module.' + k: v
                                       for k, v in data[key].items()})
            optimizer.load_state_dict(data['optimizer'])
            if data['epoch'] < C.get()['epoch']:
                epoch_start = data['epoch']
            else:
                only_eval = True
        else:
            model.load_state_dict({k: v for k, v in data.items()})
        del data
    else:
        logger.info('"%s" file not found. skip to pretrain weights...' % save_path)
        if only_eval:
            logger.warning('model checkpoint not found. only-evaluation mode is off.')
            only_eval = False

    if only_eval:
        logger.info('evaluation only+')
        model.eval()
        rs = dict()
        rs['train'] = run_epoch(model, trainloader, criterion, None,
                                desc_default='train', epoch=0, writer=writers[0])
        rs['valid'] = run_epoch(model, validloader, criterion, None,
                                desc_default='valid', epoch=0, writer=writers[1])
        rs['test'] = run_epoch(model, testloader_, criterion, None,
                               desc_default='*test', epoch=0, writer=writers[2])
        for key, setname in itertools.product(['loss', 'top1', 'top5'],
                                              ['train', 'valid', 'test']):
            if setname not in rs:
                continue
            result['%s_%s' % (key, setname)] = rs[setname][key]
        result['epoch'] = 0
        return result

    # train loop
    best_top1 = 0
    for epoch in range(epoch_start, max_epoch + 1):
        model.train()
        rs = dict()
        rs['train'] = run_epoch(model, trainloader, criterion, optimizer,
                                desc_default='train', epoch=epoch,
                                writer=writers[0], verbose=True,
                                scheduler=scheduler)
        model.eval()

        if math.isnan(rs['train']['loss']):
            raise Exception('train loss is NaN.')

        if epoch % 5 == 0 or epoch == max_epoch:
            rs['valid'] = run_epoch(model, validloader, criterion, None,
                                    desc_default='valid', epoch=epoch,
                                    writer=writers[1], verbose=True)
            rs['test'] = run_epoch(model, testloader_, criterion, None,
                                   desc_default='*test', epoch=epoch,
                                   writer=writers[2], verbose=True)

            if metric == 'last' or rs[metric]['top1'] > best_top1:
                if metric != 'last':
                    best_top1 = rs[metric]['top1']
                for key, setname in itertools.product(['loss', 'top1', 'top5'],
                                                      ['train', 'valid', 'test']):
                    result['%s_%s' % (key, setname)] = rs[setname][key]
                result['epoch'] = epoch

                writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch)
                writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch)

                reporter(loss_valid=rs['valid']['loss'],
                         top1_valid=rs['valid']['top1'],
                         loss_test=rs['test']['loss'],
                         top1_test=rs['test']['top1'])

                # save checkpoint
                if save_path:
                    logger.info('save model@%d to %s' % (epoch, save_path))
                    torch.save({
                        'epoch': epoch,
                        'log': {
                            'train': rs['train'].get_dict(),
                            'valid': rs['valid'].get_dict(),
                            'test': rs['test'].get_dict(),
                        },
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict()
                    }, save_path)
                    torch.save({
                        'epoch': epoch,
                        'log': {
                            'train': rs['train'].get_dict(),
                            'valid': rs['valid'].get_dict(),
                            'test': rs['test'].get_dict(),
                        },
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict()
                    }, save_path.replace(
                        '.pth', '_e%d_top1_%.3f_%.3f' % (epoch, rs['train']['top1'],
                                                         rs['test']['top1']) + '.pth'))

    del model
    result['top1_test'] = best_top1
    return result
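# `adjust_learning_rate_resnet` is referenced for the 'resnet' schedule above
# but not defined in this excerpt. A hedged sketch of the kind of step
# scheduler it returns; the milestone choice is an assumption (typical 90-epoch
# ImageNet ResNet decay points), not taken from the source.
def adjust_learning_rate_resnet(optimizer):
    return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 80])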
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=4096, metavar='N',
                        help='input batch size for training (default: 4096)')
    parser.add_argument('--optimizer', type=str, default='lamb',
                        choices=['lamb', 'adam', 'lars', 'sgd'],
                        help='which optimizer to use')
    parser.add_argument('--test-batch-size', type=int, default=2048, metavar='N',
                        help='input batch size for testing (default: 2048)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--wd', type=float, default=0.01, metavar='WD',
                        help='weight decay (default: 0.01)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--eta', type=float, default=0.001, metavar='e',
                        help='LARS coefficient (default: 0.001)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)
    print(args)
    print("*" * 50)

    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        # datasets.MNIST('../data', train=True, download=True,
        datasets.EMNIST('../data', train=True, download=True, split='letters',
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        # datasets.MNIST('../data', train=False,
        datasets.EMNIST('../data', train=False, split='letters',
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    num_features = 26  # EMNIST 'letters' has 26 classes
    model = Net(num_outputs=num_features).to(device)
    writer = SummaryWriter(comment="_cv_%s_%s_%s" % (args.optimizer, args.batch_size, args.lr))
    weight_decay = args.lr / args.epochs
    print(len(train_loader), len(test_loader))
    print(model)
    print(f'total params ---> {count_parameters(model)}')

    if args.optimizer == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd,
                         betas=(.9, .999), adam=False, writer=writer)
    elif args.optimizer == 'lars':
        base_optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
        optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001, writer=writer)
    elif args.optimizer == 'sgd':
        optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9,
                        weight_decay=args.wd, writer=writer)
    else:  # use adam optimizer
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd,
                         betas=(.9, .999), adam=True, writer=writer)
    print(f'Currently using the {args.optimizer}\n\n')

    metrics = {"acc": [], "test_loss": []}
    os.makedirs("cv_results", exist_ok=True)
    for epoch in range(1, args.epochs + 1):
        print("Epoch #%s" % epoch)
        train(args, model, device, train_loader, optimizer, epoch, writer)
        acc, loss = test(args, model, device, test_loader, writer, epoch)
        metrics["acc"].append(acc)
        metrics["test_loss"].append(loss)
        pickle.dump(metrics,
                    open(os.path.join("cv_results",
                                      '%s_%s_metrics.p' % (args.optimizer, args.batch_size)), 'wb'))
        # print("Epoch #%s: acc=%s, loss=%s" % (epoch, acc, loss))
    return optimizer
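# Why SGD is wrapped in LARS above: LARS scales each layer's step by a "trust ratio",
# so layers whose gradients are small relative to their weights are not under-trained
# at large batch sizes. A sketch of the rule from You et al. (2017); the exact eps and
# weight-decay placement varies between implementations, torchlars included.
import torch

def lars_local_lr(param, grad, trust_coef=0.001, weight_decay=0.0, eps=1e-8):
    p_norm = param.norm()
    g_norm = grad.norm()
    # layerwise scale: trust_coef * ||w|| / (||g|| + wd * ||w|| + eps),
    # applied on top of the base optimizer's learning rate
    return trust_coef * p_norm / (g_norm + weight_decay * p_norm + eps)

w, g = torch.randn(256, 128), torch.randn(256, 128)
print(float(lars_local_lr(w, g)))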
def train(cfg, writer, logger):
    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    logger.info("data path: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split='train',
        split_root=cfg.data.split,
        augments=data_aug,
        logger=logger,
        log=cfg.data.log,
        ENL=cfg.data.ENL,
    )
    v_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        split='val',
        log=cfg.data.log,
        split_root=cfg.data.split,
        logger=logger,
        ENL=cfg.data.ENL,
    )

    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len}\nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg.test.batch_size,
                                num_workers=cfg.train.n_workers)

    # Setup Model
    device = f'cuda:{cfg.train.gpu[0]}'
    model = get_model(cfg.model).to(device)
    input_size = (cfg.model.in_channels, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=False)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # automatic multi-GPU execution; works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrap optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    # loss_fn = get_loss_function(cfg)
    # logger.info(f"Using loss {str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info("Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume))

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(cfg.train.resume, checkpoint["epoch"]))

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir)
        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)
    # data_range /= 350

    # Setup Metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    train_loss_meter = averageMeter()
    val_psnr_meter = averageMeter()
    val_ssim_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for clean, noisy, _ in trainloader:
            it += 1
            noisy = noisy.to(device, dtype=torch.float32)
            # noisy /= 350

            mask1, mask2 = rand_pool.generate_mask_pair(noisy)
            noisy_sub1 = rand_pool.generate_subimages(noisy, mask1)
            noisy_sub2 = rand_pool.generate_subimages(noisy, mask2)

            # preparing for the regularization term
            with torch.no_grad():
                noisy_denoised = model(noisy)
            noisy_sub1_denoised = rand_pool.generate_subimages(noisy_denoised, mask1)
            noisy_sub2_denoised = rand_pool.generate_subimages(noisy_denoised, mask2)

            # calculating the loss
            noisy_output = model(noisy_sub1)
            noisy_target = noisy_sub2
            if cfg.train.loss.gamma.const:
                gamma = cfg.train.loss.gamma.base
            else:
                gamma = it / train_iter * cfg.train.loss.gamma.base
            diff = noisy_output - noisy_target
            exp_diff = noisy_sub1_denoised - noisy_sub2_denoised
            loss1 = torch.mean(diff ** 2)
            loss2 = gamma * torch.mean((diff - exp_diff) ** 2)
            loss_all = loss1 + loss2
            # loss1 = noisy_output - noisy_target
            # loss2 = torch.exp(noisy_target - noisy_output)
            # loss_all = torch.mean(loss1 + loss2)

            optimizer.zero_grad()  # added: gradients were never cleared between iterations
            loss_all.backward()
            # In PyTorch 1.1.0 and later, call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the loss of the minibatch
            train_loss_meter.update(loss_all.item())  # .item() avoids holding the autograd graph
            train_time_meter.update(time.time() - train_start_time)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], it)

            if it % 1000 == 0:
                writer.add_histogram('hist/pred', noisy_denoised, it)
                writer.add_histogram('hist/noisy', noisy, it)
                if cfg.data.simulate:
                    writer.add_histogram('hist/clean', clean, it)

            # print interval
            if it % cfg.train.print_interval == 0:
                terminal_info = (f"Iter [{it:d}/{train_iter:d}]  "
                                 f"train Loss: {train_loss_meter.avg:.4f}  "
                                 f"Time/Image: {train_time_meter.avg / cfg.train.batch_size:.4f}")
                logger.info(terminal_info)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg, it)
                runing_metrics_train.reset()
                train_time_meter.reset()
                train_loss_meter.reset()

            # val interval
            if it % cfg.train.val_interval == 0 or it == train_iter:
                val_start_time = time.time()
                model.eval()
                with torch.no_grad():
                    for clean, noisy, _ in valloader:
                        # noisy /= 350
                        # clean /= 350
                        noisy = noisy.to(device, dtype=torch.float32)
                        noisy_denoised = model(noisy)
                        if cfg.data.simulate:
                            clean = clean.to(device, dtype=torch.float32)
                            psnr = piq.psnr(clean, noisy_denoised, data_range=data_range)
                            ssim = piq.ssim(clean, noisy_denoised, data_range=data_range)
                            val_psnr_meter.update(psnr)
                            val_ssim_meter.update(ssim)
                        val_loss = torch.mean((noisy_denoised - noisy) ** 2)
                        val_loss_meter.update(val_loss)

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} "
                            f"Time/Image: {(time.time() - val_start_time) / len(v_loader):.4f}")
                val_loss_meter.reset()
                running_metrics_val.reset()

                if cfg.data.simulate:
                    writer.add_scalars('metrics/val',
                                       {'psnr': val_psnr_meter.avg, 'ssim': val_ssim_meter.avg}, it)
                    logger.info(f'psnr: {val_psnr_meter.avg},\tssim: {val_ssim_meter.avg}')
                    val_psnr_meter.reset()
                    val_ssim_meter.reset()

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            # save model
            if it % (train_iter / cfg.train.epoch * 10) == 0:
                ep = int(it / (train_iter / cfg.train.epoch))
                state = {
                    "epoch": it,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = osp.join(writer.file_writer.get_logdir(), f"{ep}.pkl")
                torch.save(state, save_path)
                logger.info(f'saved model state dict at {save_path}')

            train_start_time = time.time()
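# The checkpoints saved above hold exactly four keys; a matching resume helper, as a
# sketch (the map_location choice is an assumption; the keys come from `state` above):
import torch

def resume_training(path, model, optimizer, scheduler, device='cpu'):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    scheduler.load_state_dict(checkpoint["scheduler_state"])
    return checkpoint["epoch"]  # note: the save block above stores the iteration counter here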
resume_from = os.path.join('./checkpoint', args.resume)
checkpoint = torch.load(resume_from)
net.load_state_dict(checkpoint['net'])
critic.load_state_dict(checkpoint['critic'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
base_optimizer = optim.SGD(list(net.parameters()) + list(critic.parameters()),
                           lr=args.lr, weight_decay=1e-6, momentum=args.momentum)
if args.cosine_anneal:
    scheduler = CosineAnnealingWithLinearRampLR(base_optimizer, args.num_epochs)
encoder_optimizer = LARS(base_optimizer, trust_coef=1e-3)


# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    critic.train()
    train_loss = 0
    t = tqdm(enumerate(trainloader), desc='Loss: **** ', total=len(trainloader),
             bar_format='{desc}{bar}{r_bar}')
    for batch_idx, (inputs, _, _) in t:
        x1, x2 = inputs
        x1, x2 = x1.to(device), x2.to(device)
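# The chunk above cuts off after moving the two augmented views to the device. A sketch
# of how such a contrastive step typically continues (assumptions: `net` is the encoder
# and `critic` maps the concatenated embeddings to a scalar loss; neither is defined here):
import torch

def simclr_step(net, critic, optimizer, x1, x2):
    # embed both views, score them with the critic, and step the LARS-wrapped optimizer
    h = torch.cat([net(x1), net(x2)], dim=0)
    loss = critic(h)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()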
simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)

P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
P.shift_trans = P.shift_trans.to(device)

model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
model = C.get_shift_classifer(model, P.K_shift).to(device)

criterion = nn.CrossEntropyLoss().to(device)

if P.optimizer == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=P.lr_init,
                          momentum=0.9, weight_decay=P.weight_decay)
    lr_decay_gamma = 0.1
elif P.optimizer == 'lars':
    from torchlars import LARS
    base_optimizer = optim.SGD(model.parameters(), lr=P.lr_init,
                               momentum=0.9, weight_decay=P.weight_decay)
    optimizer = LARS(base_optimizer, eps=1e-8, trust_coef=0.001)
    lr_decay_gamma = 0.1
else:
    raise NotImplementedError()

if P.lr_scheduler == 'cosine':
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, P.epochs)
elif P.lr_scheduler == 'step_decay':
    milestones = [int(0.5 * P.epochs), int(0.75 * P.epochs)]
    scheduler = lr_scheduler.MultiStepLR(optimizer, gamma=lr_decay_gamma, milestones=milestones)
else:
    raise NotImplementedError()

from training.scheduler import GradualWarmupScheduler
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10.0,
                                          total_epoch=P.warmup, after_scheduler=scheduler)
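# How the warmup wrapper above is typically driven (a sketch, not the original loop):
# GradualWarmupScheduler ramps the lr from lr_init up to multiplier * lr_init over the
# first `total_epoch` epochs, then defers to `after_scheduler` (cosine or step decay).
for epoch in range(1, P.epochs + 1):
    # ... one training epoch over the loader, with optimizer.step() per batch ...
    scheduler_warmup.step()  # once per epoch, after the optimizer updates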
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True)
loaders = [train_loader, val_loader]
print('Total number of training images: {}, validation images: {}.'.format(
    len(train_dataset), len(val_dataset)))

model = resnet18_cifar(args.feature_size).to(device)
model = torch.nn.DataParallel(model)

base_optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay_lr)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
scheduler = ExponentialLR(optimizer, gamma=args.decay_lr)

# Main training loop
best_loss = np.inf
for epoch in range(args.epochs):
    v_loss = execute_graph(model, loaders, optimizer, scheduler, epoch)
    if v_loss < best_loss:
        best_loss = v_loss
        print('Writing model checkpoint')
        state = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
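# The chunk is cut off mid-dict. A self-contained sketch of the save-on-improvement
# pattern it implements (the 'best_model.pt' filename is an assumption, not from the original):
import torch

def save_checkpoint(model, optimizer, epoch, path='best_model.pt'):
    state = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, path)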
class task2_train():
    def __init__(self, device, train_loader, val_loader, model, MODEL_PATH):
        self.device = device
        self.model_path = MODEL_PATH
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model.to(device)
        base_optimizer = optim.Adam(self.model.parameters(), lr=2e-4)
        self.optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
        self.criterion = nn.BCEWithLogitsLoss()
        self.scaler = GradScaler()

    def train(self):
        self.model.train()
        train_loss = 0
        label_pred = []
        label_true = []
        AC_total, SE_total, SP_total, DC_total, JS_total = 0, 0, 0, 0, 0
        start = time.time()
        for batch_idx, (inputs, GT) in enumerate(self.train_loader):
            if batch_idx % 10 == 0 and batch_idx:
                end = time.time()
                print("finished training %s images, using %.4fs" % (str(batch_idx * 8), end - start))
            inputs, GT = inputs.to(self.device), GT.to(self.device)
            self.optimizer.zero_grad()
            with autocast():
                SR = self.model(inputs)
                loss = loss_f(inputs=SR, targets=GT.long())
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            train_loss += loss.item()
            label_pred.append(SR.data.max(dim=1)[1].cpu().numpy())
            label_true.append(GT.data.cpu().numpy())

        label_pred = np.concatenate(label_pred, axis=0)
        label_true = np.concatenate(label_true, axis=0)
        for lbt, lbp in zip(label_true, label_pred):
            AC, SE, SP, DC, JS = self.evaluate(lbt, lbp)
            AC_total += AC
            SE_total += SE
            SP_total += SP
            DC_total += DC
            JS_total += JS
        print('Training Loss: %.4f\nAC: %.4f, SE: %.4f, SP: %.4f, DC: %.4f, JS: %.4f' %
              (train_loss, AC_total / len(label_true), SE_total / len(label_true),
               SP_total / len(label_true), DC_total / len(label_true),
               JS_total / len(label_true)))

    def _fast_hist(self, truth, pred):
        mask = (truth >= 0) & (truth < 19)
        hist = np.bincount(19 * truth[mask].astype(int) + pred[mask],
                           minlength=19 ** 2).reshape(19, 19)
        return hist

    def evaluate(self, ground_truth, predictions, smooth=1):
        confusion_matrix = np.zeros((19, 19))
        for lt, lp in zip(ground_truth, predictions):
            confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten())
        fn = confusion_matrix.sum(1) - np.diag(confusion_matrix)
        fp = confusion_matrix.sum(0) - np.diag(confusion_matrix)
        tp = np.diag(confusion_matrix)
        tn = (np.array([confusion_matrix.sum() for i in range(19)])
              - confusion_matrix.sum(1) - confusion_matrix.sum(0)
              + np.diag(confusion_matrix))
        AC_array = (tp + tn) / np.maximum(1.0, fn + fp + tp + tn)
        AC = AC_array.mean()
        SE_array = tp / np.maximum(1.0, confusion_matrix.sum(1))
        SE = SE_array.mean()
        SP_array = tn / np.maximum(1.0, tn + fp)
        SP = SP_array.mean()
        DC_array = (2 * tp) / np.maximum(1.0, confusion_matrix.sum(0) + confusion_matrix.sum(1))
        DC = DC_array.mean()
        JS_array = tp / np.maximum(1.0,
                                   confusion_matrix.sum(0) + confusion_matrix.sum(1)
                                   - np.diag(confusion_matrix))
        JS = JS_array.mean()
        return AC, SE, SP, DC, JS

    def saveModel(self):
        torch.save(self.model.state_dict(), self.model_path + "model.pth")
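# How _fast_hist builds the confusion matrix: each pixel contributes the flat index
# 19 * truth + pred, and bincount over those indices fills a 19x19 grid (rows = ground
# truth, columns = predictions). The same idea on a toy 3-class example:
import numpy as np

truth = np.array([0, 0, 1, 2, 2])
pred = np.array([0, 1, 1, 2, 0])
hist = np.bincount(3 * truth + pred, minlength=9).reshape(3, 3)
print(hist)
# [[1 1 0]   row 0: one correct 0, one 0 predicted as 1
#  [0 1 0]   row 1: the single 1 is correct
#  [1 0 1]]  row 2: one 2 predicted as 0, one correct
print(np.diag(hist))  # per-class true positives, exactly what evaluate() uses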
                     momentum=0.9,
                     weight_decay=args.weight_decay)
    lr_decay_gamma = 0.1
elif args.optimizer == 'adam':
    # NOTE: rebinding `optim` shadows the torch.optim module import from here on
    optim = optim.Adam(model.parameters(), lr=args.lr, betas=(.9, .999),
                       weight_decay=args.weight_decay)
    lr_decay_gamma = 0.3
elif args.optimizer == 'lars':
    from torchlars import LARS
    base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                               weight_decay=args.weight_decay)
    optim = LARS(base_optimizer, eps=1e-8, trust_coef=0.001)
    lr_decay_gamma = 0.1
elif args.optimizer == 'ranger':
    from ranger import Ranger
    optim = Ranger(model.parameters(), weight_decay=args.weight_decay, lr=args.lr)
else:
    raise NotImplementedError()

# normal_class=[args.known_normal], known_outlier_class=args.known_outlier,
print("known_normal:", args.known_normal, "known_outlier:", args.known_outlier)

# if args.lr_scheduler == 'cosine':
#     scheduler = lr_scheduler.CosineAnnealingLR(optim, args.n_epochs)
#     print("use cosine scheduler")

# Evaluation before training
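# Why the NOTE above matters: once the name `optim` is bound to an optimizer instance,
# the torch.optim module is no longer reachable under that name, so any later
# `optim.SGD(...)`-style call in the same script raises AttributeError. A minimal
# reproduction:
import torch
import torch.optim as optim

params = [torch.nn.Parameter(torch.zeros(1))]
optim = optim.Adam(params, lr=1e-3)  # rebinds the module name to an Adam instance
# optim.SGD(params, lr=0.1)  # would now fail; prefer a distinct name like `optimizer`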
args.add_argument('--seed', type=int)
args.add_argument('--lr', type=float, default=1e-4)
config = args.parse_args()

if 'DEVICE' in os.environ:
    config.device = os.environ['DEVICE']

wandb.init(project='EfficientNet_MNIST', config=config)

if config.seed is not None:
    torch.manual_seed(config.seed)

eff_net = efficient_net.EfficientNet(version="b0", num_classes=10).to(config.device)

if config.lars:
    optim = LARS(Adam(eff_net.parameters(), lr=config.lr))
else:
    optim = Adam(eff_net.parameters(), lr=config.lr)

ds = MNIST('~/.mnist', train=True, download=True, transform=transforms.ToTensor())
dl = DataLoader(ds, batch_size=config.batch_size)
test_s = MNIST('~/.mnist', train=False, download=True, transform=transforms.ToTensor())
test = DataLoader(test_s, batch_size=config.batch_size)
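# The chunk stops after building the loaders. A minimal sketch of how these pieces
# might be driven (train_epoch, the loss choice, the epoch count, and the wandb keys
# are assumptions, not from the original):
import torch.nn.functional as F

def train_epoch(model, loader, optimizer, device):
    model.train()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()  # with config.lars, the LARS wrapper rescales layer steps before Adam applies them
    return loss.item()

for epoch in range(10):
    last_loss = train_epoch(eff_net, dl, optim, config.device)
    wandb.log({'epoch': epoch, 'train_loss': last_loss})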