def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root='./data', train=True, download=True,
                          transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch,
                                  shuffle=True, num_workers=args.workers)
    testset = dataloader(root='./data', train=False, download=False,
                         transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.test_batch,
                                 shuffle=False, num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('resnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            block_name=args.block_name,
        )
    elif args.arch.startswith('preresnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            block_name=args.block_name,
        )
    elif args.arch.startswith('horesnet'):
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           depth=args.depth,
                                           eta=args.eta,
                                           block_name=args.block_name,
                                           feature_vec=args.feature_vec)
    elif args.arch.startswith('hopreresnet'):
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           depth=args.depth,
                                           eta=args.eta,
                                           block_name=args.block_name,
                                           feature_vec=args.feature_vec)
    elif args.arch.startswith('nagpreresnet'):
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           depth=args.depth,
                                           eta=args.eta,
                                           block_name=args.block_name,
                                           feature_vec=args.feature_vec)
    elif args.arch.startswith('mompreresnet'):
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           depth=args.depth,
                                           eta=args.eta,
                                           block_name=args.block_name,
                                           feature_vec=args.feature_vec)
    elif args.arch.startswith('v2_preresnet'):
        # map depth to the standard (Pre-)ResNet block configuration
        if args.depth == 18:
            block_name = 'basicblock'
            num_blocks = [2, 2, 2, 2]
        elif args.depth == 34:
            block_name = 'basicblock'
            num_blocks = [3, 4, 6, 3]
        elif args.depth == 50:
            block_name = 'bottleneck'
            num_blocks = [3, 4, 6, 3]
        elif args.depth == 101:
            block_name = 'bottleneck'
            num_blocks = [3, 4, 23, 3]
        elif args.depth == 152:
            block_name = 'bottleneck'
            num_blocks = [3, 8, 36, 3]
        else:
            raise ValueError('Unsupported depth %d for v2_preresnet' % args.depth)
        model = models.__dict__[args.arch](block_name=block_name,
                                           num_blocks=num_blocks,
                                           num_classes=num_classes)
    else:
        print('Unknown architecture - falling back to the standard model')
        model = models.__dict__[args.arch](num_classes=num_classes)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print(' Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()
    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(), lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'adamw':
        optimizer = AdamW(model.parameters(), lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay,
                          warmup=args.warmup)
    elif args.optimizer.lower() == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               betas=(args.beta1, args.beta2),
                               weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'srsgd':
        iter_count = 1
        optimizer = SGD_Adaptive(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 iter_count=iter_count,
                                 restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'sradam':
        iter_count = 1
        optimizer = SRNAdam(model.parameters(), lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'sradamw':
        iter_count = 1
        optimizer = SRAdamW(model.parameters(), lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            warmup=args.warmup,
                            restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'srradam':
        # NOTE: need to double-check this
        iter_count = 1
        optimizer = SRRAdam(model.parameters(), lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            warmup=args.warmup,
                            restarting_iter=args.restart_schedule[0])

    schedule_index = 1

    # Resume
    title = '%s-' % args.dataset + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint found!'
        # args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.optimizer.lower() in ('srsgd', 'sradam', 'sradamw', 'srradam'):
            iter_count = optimizer.param_groups[0]['iter_count']
        # restore the LR schedule position; fall back to 3 for old checkpoints
        # that did not save it
        schedule_index = checkpoint.get('schedule_index', 3)
        state['lr'] = optimizer.param_groups[0]['lr']
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch,
                                   use_cuda)
        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        if args.optimizer.lower() == 'srsgd':
            if epoch == 161:
                # from epoch 161 on, the restart interval decays linearly towards 1
                start_decay_restarting_iter = args.restart_schedule[
                    schedule_index] - 1
                current_lr = args.lr * (args.gamma**schedule_index)
            if epoch in args.schedule:
                current_lr = args.lr * (args.gamma**schedule_index)
                current_restarting_iter = args.restart_schedule[schedule_index]
                optimizer = SGD_Adaptive(
                    model.parameters(),
                    lr=current_lr,
                    weight_decay=args.weight_decay,
                    iter_count=iter_count,
                    restarting_iter=current_restarting_iter)
                schedule_index += 1
            if epoch >= 161:
                current_restarting_iter = start_decay_restarting_iter * (
                    args.epochs - epoch - 1) / (args.epochs - 162) + 1
                optimizer = SGD_Adaptive(
                    model.parameters(),
                    lr=current_lr,
                    weight_decay=args.weight_decay,
                    iter_count=iter_count,
                    restarting_iter=current_restarting_iter)
        else:
            adjust_learning_rate(optimizer, epoch)

        logger.file.write('\nEpoch: [%d | %d] LR: %f' %
                          (epoch + 1, args.epochs, state['lr']))

        if args.optimizer.lower() in ('srsgd', 'sradam', 'sradamw', 'srradam'):
            train_loss, train_acc, iter_count = train(trainloader, model,
                                                      criterion, optimizer,
                                                      epoch, use_cuda, logger)
        else:
            train_loss, train_acc = train(trainloader, model, criterion,
                                          optimizer, epoch, use_cuda, logger)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   use_cuda, logger)

        # append logger file
        logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

        writer.add_scalars('train_loss', {args.model_name: train_loss}, epoch)
        writer.add_scalars('test_loss', {args.model_name: test_loss}, epoch)
        writer.add_scalars('train_acc', {args.model_name: train_acc}, epoch)
        writer.add_scalars('test_acc', {args.model_name: test_acc}, epoch)

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'schedule_index': schedule_index,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            epoch,
            checkpoint=args.checkpoint)

    logger.file.write('Best acc:%f' % best_acc)
    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)

    with open("./all_results.txt", "a") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        f.write("%s\n" % args.checkpoint)
        f.write("best_acc %f\n\n" % best_acc)
        fcntl.flock(f, fcntl.LOCK_UN)
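# The SRSGD branch above rebuilds its optimizer by hand, while every other
# optimizer goes through `adjust_learning_rate`. A minimal sketch of what such a
# helper typically does with `state['lr']`, `args.schedule`, and `args.gamma` as
# used above (assumed; `adjust_learning_rate_sketch` is a hypothetical name, not
# necessarily the implementation this script imports):
def adjust_learning_rate_sketch(optimizer, epoch):
    global state
    if epoch in args.schedule:
        # decay the global learning rate by gamma at each scheduled epoch
        state['lr'] *= args.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']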
def train(train_data,
          exp_dir=datetime.now().strftime("corrector_model/%Y-%m-%d_%H%M"),
          learning_rate=0.00005,
          rsize=10,
          epochs=1,
          checkpoint_path='',
          seed=6548,
          batch_size=4,
          edge_loss=False,
          model_type='cnet',
          model_cap='normal',
          optimizer_type='radam',
          reset_optimizer=True,  # if true, does not load optimizer checkpoints
          safe_descent=True,
          activation_type='mish',
          activation_args={},
          io=None,
          dynamic_lr=True,
          dropout=0,
          rotations=False,
          use_batch_norm=True,
          batch_norm_momentum=None,
          batch_norm_affine=True,
          use_gc=True,
          no_lr_schedule=False,
          diff_features_only=False):

    start_time = time.time()

    io.cprint("-------------------------------------------------------" +
              "\nexport dir = " + '/checkpoints/' + exp_dir +
              "\nbase_learning_rate = " + str(learning_rate) +
              "\nuse_batch_norm = " + str(use_batch_norm) +
              "\nbatch_norm_momentum = " + str(batch_norm_momentum) +
              "\nbatch_norm_affine = " + str(batch_norm_affine) +
              "\nno_lr_schedule = " + str(no_lr_schedule) +
              "\nuse_gc = " + str(use_gc) +
              "\nrsize = " + str(rsize) +
              "\npython_version: " + sys.version +
              "\ntorch_version: " + torch.__version__ +
              "\nnumpy_version: " + np.version.version +
              "\nmodel_type: " + model_type +
              "\nmodel_cap: " + model_cap +
              "\noptimizer: " + optimizer_type +
              "\nactivation_type: " + activation_type +
              "\nsafe_descent: " + str(safe_descent) +
              "\ndynamic_lr: " + str(dynamic_lr) +
              "\nrotations: " + str(rotations) +
              "\nepochs = " + str(epochs) +
              (("\ncheckpoint = " + checkpoint_path) if
               (checkpoint_path is not None and checkpoint_path != '') else '') +
              "\nseed = " + str(seed) +
              "\nbatch_size = " + str(batch_size) +
              "\n#train_data = " + str(sum([bin.size(0) for bin in train_data["train_bins"]])) +
              "\n#test_data = " + str(len(train_data["test_samples"])) +
              "\n#validation_data = " + str(len(train_data["val_samples"])) +
              "\nedge_loss = " + str(edge_loss) +
              "\n-------------------------------------------------------" +
              "\nstart_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")

    # initialize torch & cuda ---------------------------------------------------------------------
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = utils.getDevice(io)

    # extract train- & test data (and move to device) --------------------------------------------
    # train_bins = [bin.float().to(device) for bin in train_data["train_bins"]]
    # test_samples = [sample.float().to(device) for sample in train_data["test_samples"]]
    # val_samples = [sample.float().to(device) for sample in train_data["val_samples"]]
    train_bins = [bin.float() for bin in train_data["train_bins"]]
    test_samples = [sample.float() for sample in train_data["test_samples"]]
    val_samples = [sample.float() for sample in train_data["val_samples"]]

    # Initialize Model ------------------------------------------------------------------------------
    model_args = {
        'model_type': model_type,
        'model_cap': model_cap,
        'input_channels': test_samples[0].size(1),
        'output_channels': test_samples[0].size(1),
        'rsize': rsize,
        'emb_dims': 1024,
        'activation_type': activation_type,
        'activation_args': activation_args,
        'dropout': dropout,
        'batch_norm': use_batch_norm,
        'batch_norm_affine': batch_norm_affine,
        'batch_norm_momentum': batch_norm_momentum,
        'diff_features_only': diff_features_only
    }
    model = getModel(model_args).to(device)

    # init optimizer & scheduler -------------------------------------------------------------------
    lookahead_sync_period = 6
    optimizer = None
    if optimizer_type == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=learning_rate,
                          betas=(0.9, 0.999),
                          eps=1e-8,
                          use_gc=use_gc)
    elif optimizer_type == 'lookahead':
        optimizer = Ranger(model.parameters(),
                           lr=learning_rate,
                           alpha=0.9,
                           k=lookahead_sync_period)

    # make sure that either a LR schedule is given or dynamic LR is enabled
    assert dynamic_lr or not no_lr_schedule
    scheduler = None if no_lr_schedule else MultiplicativeLR(
        optimizer, lr_lambda=MultiplicativeAnnealing(epochs))

    # set train settings & load previous model state ------------------------------------------------------------
    checkpoint = getEmptyCheckpoint()
    last_epoch = 0
    if (checkpoint_path is not None and checkpoint_path != ''):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'][-1])
        if not reset_optimizer:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'][-1])
        last_epoch = len(checkpoint['model_state_dict'])
        print('> loaded checkpoint! (%d epochs)' % (last_epoch))

    checkpoint['train_settings'].append({
        'learning_rate': learning_rate,
        'scheduler': scheduler,
        'epochs': epochs,
        'seed': seed,
        'batch_size': batch_size,
        'edge_loss': edge_loss,
        'optimizer': optimizer_type,
        'safe_descent': str(safe_descent),
        'dynamic_lr': str(dynamic_lr),
        'rotations': str(rotations),
        'train_data_count': sum([bin.size(0) for bin in train_data["train_bins"]]),
        'test_data_count': len(train_data["test_samples"]),
        'validation_data_count': len(train_data["val_samples"]),
        'model_args': model_args
    })

    # set up report interval (for logging) and batch size -------------------------------------------------------------------
    report_interval = 100
    loss_function = torch.nn.MSELoss(reduction='mean')

    # begin training ###########################################################################################################################
    io.cprint("\nBeginning Training..\n")

    for epoch in range(last_epoch + 1, last_epoch + epochs + 1):

        io.cprint(
            "Epoch: %d ------------------------------------------------------------------------------------------"
            % (epoch))
        io.cprint("Current LR: %.10f" % (optimizer.param_groups[0]['lr']))

        model.train()
        optimizer.zero_grad()

        checkpoint['train_batch_loss'].append([])
        checkpoint['train_batch_N'].append([])
        checkpoint['train_batch_lr_adjust'].append([])
        checkpoint['train_batch_loss_reduction'].append([])
        checkpoint['lr'].append(optimizer.param_groups[0]['lr'])

        # draw random batches from random bins
        binbatches = utils.drawBinBatches([bin.size(0) for bin in train_bins],
                                          batchsize=batch_size)
        checkpoint['train_batch_N'][-1] = [
            train_bins[bin_id][batch_ids].size(1)
            for (bin_id, batch_ids) in binbatches
        ]

        failed_loss_optims = 0
        cum_lr_adjust_fac = 0
        cum_loss_reduction = 0

        # pre-compute random rotations if needed
        batch_rotations = [None] * len(binbatches)
        if rotations:
            start_rotations = time.time()
            batch_rotations = torch.zeros(
                (len(binbatches), batch_size, test_samples[0].size(1),
                 test_samples[0].size(1)),
                device=device)
            for i in range(len(binbatches)):
                for j in range(batch_size):
                    batch_rotations[i, j] = utils.getRandomRotation(
                        test_samples[0].size(1), device=device)
            print("created batch rotations (%ds)" %
                  (time.time() - start_rotations))

        b = 0  # batch counter
        train_start = time.time()

        for (bin_id, batch_ids) in binbatches:

            b += 1
            # print("handling batch %d" % (b))

            # prediction & loss ----------------------------------------
            batch_sample = train_bins[bin_id][batch_ids].to(
                model.base.device)  # size: (B x N x d x 2)
            batch_loss = getBatchLoss(model,
                                      batch_sample,
                                      loss_function,
                                      edge_loss=edge_loss,
                                      rotations=batch_rotations[b - 1])
            batch_loss.backward()
            checkpoint['train_batch_loss'][-1].append(batch_loss.item())

            new_loss = 0.0
            lr_adjust = 1.0
            loss_reduction = 0.0

            # if safe descent is enabled, try to optimize the descent step so that a reduction in loss is guaranteed
            if safe_descent:

                # create backups to restore states before the optimizer step
                model_state_backup = copy.deepcopy(model.state_dict())
                opt_state_backup = copy.deepcopy(optimizer.state_dict())

                # make an optimizer step
                optimizer.step()

                # in each iteration, check if the optimizer gave an improvement;
                # if not, restore the original states, reduce the learning rate and try again
                # no gradient needed for the plain loss calculation
                with torch.no_grad():
                    for i in range(10):
                        new_loss = getBatchLoss(
                            model,
                            batch_sample,
                            loss_function,
                            edge_loss=edge_loss,
                            rotations=batch_rotations[b - 1]).item()
                        # if the model performs better now we continue, if not we try a smaller learning step
                        if (new_loss < batch_loss.item()):
                            # print("lucky! (%f -> %f) reduction: %.4f%%" % (batch_loss.item(), new_loss, 100 * (batch_loss.item()-new_loss) / batch_loss.item()))
                            break
                        else:
                            # print("try again.. (%f -> %f)" % (batch_loss.item(), new_loss))
                            model.load_state_dict(model_state_backup)
                            optimizer.load_state_dict(opt_state_backup)
                            lr_adjust *= 0.7
                            optimizer.step(lr_adjust=lr_adjust)

                loss_reduction = 100 * (batch_loss.item() -
                                        new_loss) / batch_loss.item()

                if new_loss >= batch_loss.item():
                    failed_loss_optims += 1
                else:
                    cum_lr_adjust_fac += lr_adjust
                    cum_loss_reduction += loss_reduction
            else:
                cum_lr_adjust_fac += lr_adjust
                optimizer.step()

            checkpoint['train_batch_lr_adjust'][-1].append(lr_adjust)
            checkpoint['train_batch_loss_reduction'][-1].append(loss_reduction)

            # reset gradients
            optimizer.zero_grad()

            # statistic calculation and output -------------------------
            if b % report_interval == 0:
                last_100_loss = sum(checkpoint['train_batch_loss'][-1]
                                    [b - report_interval:b]) / report_interval
                improvement_indicator = '+' if epoch > 1 and last_100_loss < checkpoint[
                    'train_loss'][-1] else ''
                io.cprint(
                    ' Batch %4d to %4d | loss: %.10f%1s | av. dist. per neighbor: %.10f | E%3d | T:%5ds | Failed Optims: %3d (%05.2f%%) | Av. Adjust LR: %.6f | Av. Loss Reduction: %07.4f%%'
                    % (b - (report_interval - 1), b, last_100_loss,
                       improvement_indicator, np.sqrt(last_100_loss), epoch,
                       time.time() - train_start, failed_loss_optims,
                       100 * (failed_loss_optims / report_interval),
                       (cum_lr_adjust_fac /
                        (report_interval - failed_loss_optims)
                        if failed_loss_optims < report_interval else -1),
                       (cum_loss_reduction /
                        (report_interval - failed_loss_optims)
                        if failed_loss_optims < report_interval else -1)))
                failed_loss_optims = 0
                cum_lr_adjust_fac = 0
                cum_loss_reduction = 0

        checkpoint['train_loss'].append(
            sum(checkpoint['train_batch_loss'][-1]) / b)
        checkpoint['train_time'].append(time.time() - train_start)

        io.cprint(
            '----\n TRN | time: %5ds | loss: %.10f | av. dist. per neighbor: %.10f'
            % (checkpoint['train_time'][-1], checkpoint['train_loss'][-1],
               np.sqrt(checkpoint['train_loss'][-1])))

        torch.cuda.empty_cache()

        ####################
        # Test & Validation
        ####################

        with torch.no_grad():

            if use_batch_norm:
                model.eval_bn()
                eval_bn_start = time.time()
                # run through all train samples again to accumulate layer-wise input
                # distribution statistics (mean and variance) with fixed weights;
                # these statistics are later used for the BatchNorm layers during inference
                for (bin_id, batch_ids) in binbatches:
                    input = train_bins[bin_id][batch_ids][:, :, :, 0].squeeze(
                        -1)  # size: (B x N x d)
                    model(input.transpose(1, 2).to(
                        model.base.device)).transpose(1, 2)  # size: (B x N x d)
                io.cprint('Accumulated BN Layer statistics (%ds)' %
                          (time.time() - eval_bn_start))

            model.eval()

            test_start = time.time()
            test_loss = getTestLoss(model,
                                    test_samples,
                                    loss_function,
                                    edge_loss=edge_loss)
            checkpoint['test_loss'].append(test_loss)
            checkpoint['test_time'].append(time.time() - test_start)
            io.cprint(
                ' TST | time: %5ds | loss: %.10f | av. dist. per neighbor: %.10f'
                % (checkpoint['test_time'][-1], checkpoint['test_loss'][-1],
                   np.sqrt(checkpoint['test_loss'][-1])))

            val_start = time.time()
            val_loss = getTestLoss(model,
                                   val_samples,
                                   loss_function,
                                   edge_loss=edge_loss)
            checkpoint['val_loss'].append(val_loss)
            checkpoint['val_time'].append(time.time() - val_start)
            io.cprint(
                ' VAL | time: %5ds | loss: %.10f | av. dist. per neighbor: %.10f'
                % (checkpoint['val_time'][-1], checkpoint['val_loss'][-1],
                   np.sqrt(checkpoint['val_loss'][-1])))

        ####################
        # Scheduler Step
        ####################

        if not no_lr_schedule:
            scheduler.step()

        if epoch > 1 and dynamic_lr and sum(
                checkpoint['train_batch_lr_adjust'][-1]) > 0:
            io.cprint("----\n dynamic lr adjust: %.10f" %
                      (0.5 *
                       (1 + sum(checkpoint['train_batch_lr_adjust'][-1]) /
                        len(checkpoint['train_batch_lr_adjust'][-1]))))
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5 * (
                    1 + sum(checkpoint['train_batch_lr_adjust'][-1]) /
                    len(checkpoint['train_batch_lr_adjust'][-1]))

        # Save model and optimizer state ..
        checkpoint['model_state_dict'].append(copy.deepcopy(
            model.state_dict()))
        checkpoint['optimizer_state_dict'].append(
            copy.deepcopy(optimizer.state_dict()))
        torch.save(checkpoint, exp_dir + '/corrector_checkpoints.t7')

    io.cprint("\n-------------------------------------------------------" +
              ("\ntotal_time: %.2fh" % ((time.time() - start_time) / 3600)) +
              ("\ntrain_time: %.2fh" % (sum(checkpoint['train_time']) / 3600)) +
              ("\ntest_time: %.2fh" % (sum(checkpoint['test_time']) / 3600)) +
              ("\nval_time: %.2fh" % (sum(checkpoint['val_time']) / 3600)) +
              "\n-------------------------------------------------------" +
              "\nend_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")
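# Both this trainer and the detector trainer further below append per-epoch and
# per-batch statistics to a dict returned by `getEmptyCheckpoint()`. A minimal
# sketch (assumed; the repository's helper may enumerate its keys explicitly):
from collections import defaultdict

def getEmptyCheckpoint_sketch():
    # every statistic ('train_batch_loss', 'model_state_dict', ...) starts as an
    # empty list, so the training loops can append without key checks
    return defaultdict(list)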
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root='./data', train=True, download=True,
                          transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch,
                                  shuffle=True, num_workers=args.workers)
    testset = dataloader(root='./data', train=False, download=False,
                         transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.test_batch,
                                 shuffle=False, num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.endswith('resnet'):  # matches both 'resnet' and 'preresnet'
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            block_name=args.block_name,
        )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print(' Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()
    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    # elif args.optimizer.lower() == 'adam':
    #     optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(), lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'adamw':
        optimizer = AdamW(model.parameters(), lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay,
                          warmup=args.warmup)

    # Resume
    title = 'cifar-10-' + args.arch
    # if args.resume:
    #     # Load checkpoint.
    #     print('==> Resuming from checkpoint..')
    #     assert os.path.isfile(args.resume), 'Error: no checkpoint found!'
    #     args.checkpoint = os.path.dirname(args.resume)
    #     checkpoint = torch.load(args.resume)
    #     best_acc = checkpoint['best_acc']
    #     start_epoch = checkpoint['epoch']
    #     model.load_state_dict(checkpoint['state_dict'])
    #     optimizer.load_state_dict(checkpoint['optimizer'])
    #     logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    # else:
    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
    logger.set_names([
        'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
        'Valid Acc.'
    ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch,
                                   use_cuda)
        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer,
                                      epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   use_cuda)

        # append logger file
        logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

        # writer.add_scalars('loss_tracking/train_loss', {args.model_name: train_loss}, epoch)
        # writer.add_scalars('loss_tracking/test_loss', {args.model_name: test_loss}, epoch)
        # writer.add_scalars('loss_tracking/train_acc', {args.model_name: train_acc}, epoch)
        # writer.add_scalars('loss_tracking/test_acc', {args.model_name: test_acc}, epoch)

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)
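# Both main() variants call a `save_checkpoint` helper (the first variant also
# passes the epoch). A minimal sketch of the usual latest-plus-best pattern
# (assumed; `save_checkpoint_sketch` is illustrative, not the repository's exact
# implementation):
import shutil

def save_checkpoint_sketch(state, is_best, epoch=None, checkpoint='checkpoint',
                           filename='checkpoint.pth.tar'):
    # always write the latest state; keep a separate copy of the best one
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath,
                        os.path.join(checkpoint, 'model_best.pth.tar'))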
def main_worker(gpu, ngpus_per_node, args, writer):
    global best_f1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
        model.fc = nn.Linear(512 * 1, 4)  # replace final classifier
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch == 'residual_attention_network':
            from model.residual_attention_network import ResidualAttentionModel_92
            model = ResidualAttentionModel_92(num_classes=4)
        else:
            model = models.__dict__[args.arch](num_classes=4)

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss(reduction='none').cuda(args.gpu)

    # Better Adam optimizer: https://github.com/LiyuanLucasLiu/RAdam
    optimizer = RAdam(model.parameters(), lr=args.lr)

    # optionally resume from a checkpoint
    checkpoint = None
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_f1 = checkpoint['best_f1']
            if args.gpu is not None:
                # best_f1 may be from a checkpoint from a different GPU
                best_f1 = best_f1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    train_data = args.train_data
    val_data = args.val_data
    if not args.evaluate:
        if not (os.path.isfile(train_data)
                and os.path.splitext(train_data)[-1] == '.csv'):
            RoofImages.to_csv_datasource(train_data,
                                         csv_filename='tmp_train_set.csv',
                                         calc_perf=True)
            traindir = 'tmp_train_set.csv'
        else:
            traindir = args.train_data
        if not (os.path.isfile(val_data)
                and os.path.splitext(val_data)[-1] == '.csv'):
            RoofImages.to_csv_datasource(val_data,
                                         csv_filename='tmp_val_set.csv',
                                         calc_perf=True)
            valdir = 'tmp_val_set.csv'
        else:
            valdir = args.val_data
    else:
        if not args.resume:
            print('Evaluation is chosen without resuming from a checkpoint. '
                  'Please choose a checkpoint to load with the --resume parameter.')
            exit(1)
        if not (os.path.isfile(val_data)
                and os.path.splitext(val_data)[-1] == '.csv'):
            RoofImages.to_csv_datasource(val_data,
                                         csv_filename='tmp_val_set.csv',
                                         calc_perf=True)
            valdir = 'tmp_val_set.csv'
        else:
            valdir = args.val_data

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if not args.evaluate:
        train_dataset = RoofImages(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomGrayscale(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        if args.distributed:
            if args.weighted_sampling:
                print('Warning: Weighted sampling not implemented for '
                      'distributed training, so no weighted sampling will be performed')
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)
        else:
            if args.weighted_sampling:
                train_weights = np.array(train_dataset.train_weights)
                train_sampler = torch.utils.data.WeightedRandomSampler(
                    train_weights, len(train_weights), replacement=True)
            else:
                train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)

    # `checkpoint` is None when not resuming, and older checkpoints may lack 'classes'
    try:
        classes = checkpoint['classes']
    except (TypeError, KeyError):
        classes = None

    val_loader = torch.utils.data.DataLoader(
        RoofImages(valdir,
                   transforms.Compose([
                       transforms.Resize(
                           256 if args.val_resize is None else args.val_resize),
                       transforms.CenterCrop(224),
                       transforms.ToTensor(),
                       normalize,
                   ]),
                   test_mode=args.evaluate,
                   classes=classes),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    if args.evaluate:
        f1, results = validate(val_loader, model, criterion, args, 0, writer,
                               test_mode=True)
        results.to_csv(args.result_file)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        idx_vec, loss_vec = train(train_loader, model, criterion, optimizer,
                                  epoch, args, writer)

        if args.loss_weighting:
            print('Weighting losses')
            train_set_order = np.argsort(idx_vec)
            # normalize per-sample losses to [0, 1] and use them as sampling weights
            loss_vec = (loss_vec - loss_vec.min()) / (loss_vec.max() - loss_vec.min())
            loss_vec_in_order = loss_vec[train_set_order]
            train_sampler.weights = torch.as_tensor(loss_vec_in_order,
                                                    dtype=torch.double)

        # evaluate on validation set
        f1, results = validate(val_loader, model, criterion, args, epoch, writer)
        results.to_csv(args.result_file)

        # remember best f1 and save checkpoint
        is_best = f1 > best_f1
        best_f1 = max(f1, best_f1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_f1': best_f1,
                    'optimizer': optimizer.state_dict(),
                    # Save the exact classnames this checkpoint was created with
                    'classes': train_dataset.classes
                },
                is_best=is_best,
                file_folder=args.log_dir)
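# `main_worker` follows the signature used by the torchvision ImageNet example,
# so it is presumably launched once per GPU. A minimal launch sketch (assumed;
# `launch_sketch` is illustrative and reuses the `args` fields referenced above):
import torch.multiprocessing as mp

def launch_sketch(args, writer):
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # one process per GPU; the total world size scales accordingly
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker,
                 nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args, writer))
    else:
        main_worker(args.gpu, ngpus_per_node, args, writer)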
def train(train_data,
          exp_dir=datetime.now().strftime("detector_model/%Y-%m-%d_%H%M"),
          learning_rate=0.00005,
          rsize=10,
          epochs=1,
          checkpoint_path='',
          seed=6548,
          batch_size=4,
          model_type='cnet',
          model_cap='normal',
          optimizer='radam',
          safe_descent=True,
          activation_type='mish',
          activation_args={},
          io=None,
          dynamic_lr=True,
          dropout=0,
          rotations=False,
          use_batch_norm=True,
          batch_norm_momentum=None,
          batch_norm_affine=True,
          use_gc=True,
          no_lr_schedule=False,
          diff_features_only=False,
          scale_min=1,
          scale_max=1,
          noise=0):

    start_time = time.time()

    scale_min = scale_min if scale_min < 1 else 1
    scale_max = scale_max if scale_max > 1 else 1

    io.cprint("-------------------------------------------------------" +
              "\nexport dir = " + '/checkpoints/' + exp_dir +
              "\nbase_learning_rate = " + str(learning_rate) +
              "\nuse_batch_norm = " + str(use_batch_norm) +
              "\nbatch_norm_momentum = " + str(batch_norm_momentum) +
              "\nbatch_norm_affine = " + str(batch_norm_affine) +
              "\nno_lr_schedule = " + str(no_lr_schedule) +
              "\nuse_gc = " + str(use_gc) +
              "\nrsize = " + str(rsize) +
              "\npython_version: " + sys.version +
              "\ntorch_version: " + torch.__version__ +
              "\nnumpy_version: " + np.version.version +
              "\nmodel_type: " + model_type +
              "\nmodel_cap: " + model_cap +
              "\noptimizer: " + optimizer +
              "\nactivation_type: " + activation_type +
              "\nsafe_descent: " + str(safe_descent) +
              "\ndynamic_lr: " + str(dynamic_lr) +
              "\nrotations: " + str(rotations) +
              "\nscaling: " + str(scale_min) + " to " + str(scale_max) +
              "\nnoise: " + str(noise) +
              "\nepochs = " + str(epochs) +
              (("\ncheckpoint = " + checkpoint_path) if checkpoint_path != '' else '') +
              "\nseed = " + str(seed) +
              "\nbatch_size = " + str(batch_size) +
              "\n#train_data = " + str(sum([bin.size(0) for bin in train_data["train_bins"]])) +
              "\n#test_data = " + str(len(train_data["test_samples"])) +
              "\n#validation_data = " + str(len(train_data["val_samples"])) +
              "\n-------------------------------------------------------" +
              "\nstart_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")

    # initialize torch & cuda ---------------------------------------------------------------------
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = utils.getDevice(io)

    # extract train- & test data (and move to device) --------------------------------------------
    pts = train_data["pts"].to(device)
    val_pts = train_data["val_pts"].to(device)
    train_bins = train_data["train_bins"]
    test_samples = train_data["test_samples"]
    val_samples = train_data["val_samples"]

    # the maximum noise offset for each point is equal to the distance to its nearest neighbor
    max_noise = torch.square(pts[train_data["knn"][:, 0]] - pts).sum(dim=1).sqrt()

    # Initialize Model ------------------------------------------------------------------------------
    model_args = {
        'model_type': model_type,
        'model_cap': model_cap,
        'input_channels': pts.size(1),
        'output_channels': 2,
        'rsize': rsize,
        'emb_dims': 1024,
        'activation_type': activation_type,
        'activation_args': activation_args,
        'dropout': dropout,
        'batch_norm': use_batch_norm,
        'batch_norm_affine': batch_norm_affine,
        'batch_norm_momentum': batch_norm_momentum,
        'diff_features_only': diff_features_only
    }
    model = getModel(model_args).to(device)

    # init optimizer & scheduler -------------------------------------------------------------------
    lookahead_sync_period = 6
    opt = None
    if optimizer == 'radam':
        opt = RAdam(model.parameters(),
                    lr=learning_rate,
                    betas=(0.9, 0.999),
                    eps=1e-8,
                    use_gc=use_gc)
    elif optimizer == 'lookahead':
        opt = Ranger(model.parameters(),
                     lr=learning_rate,
                     alpha=0.9,
                     k=lookahead_sync_period)

    # make sure that either a LR schedule is given or dynamic LR is enabled
    assert dynamic_lr or not no_lr_schedule
    scheduler = None if no_lr_schedule else MultiplicativeLR(
        opt, lr_lambda=MultiplicativeAnnealing(epochs))

    # set train settings & load previous model state ------------------------------------------------------------
    checkpoint = getEmptyCheckpoint()
    last_epoch = 0
    if (checkpoint_path != ''):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'][-1])
        # the optimizer instance is `opt`; `optimizer` is the type-name string
        opt.load_state_dict(checkpoint['optimizer_state_dict'][-1])
        last_epoch = len(checkpoint['model_state_dict'])
        print('> loaded checkpoint! (%d epochs)' % (last_epoch))

    checkpoint['train_settings'].append({
        'learning_rate': learning_rate,
        'scheduler': scheduler,
        'epochs': epochs,
        'seed': seed,
        'batch_size': batch_size,
        'optimizer': optimizer,
        'safe_descent': str(safe_descent),
        'dynamic_lr': str(dynamic_lr),
        'rotations': str(rotations),
        'scale_min': scale_min,
        'scale_max': scale_max,
        'noise': noise,
        'train_data_count': sum([bin.size(0) for bin in train_data["train_bins"]]),
        'test_data_count': len(train_data["test_samples"]),
        'validation_data_count': len(train_data["val_samples"]),
        'model_args': model_args
    })

    # calculate class weights ---------------------------------------------------------------------
    av_c1_freq = sum([
        torch.sum(bin[:, :, 1]).item() for bin in train_data["train_bins"]
    ]) / sum([bin[:, :, 1].numel() for bin in train_data["train_bins"]])
    class_weights = torch.tensor([av_c1_freq, 1 - av_c1_freq]).float().to(device)

    io.cprint("\nC0 Weight: %.4f" % (class_weights[0].item()))
    io.cprint("C1 Weight: %.4f" % (class_weights[1].item()))

    # Adjust Weights in favor of C1 (edge:true class)
    # class_weights[0] = class_weights[0] / 2
    # class_weights[1] = 1 - class_weights[0]
    # io.cprint("\nAdjusted C0 Weight: %.4f" % (class_weights[0].item()))
    # io.cprint("Adjusted C1 Weight: %.4f" % (class_weights[1].item()))

    # set up report interval (for logging) and batch size -------------------------------------------------------------------
    report_interval = 100

    # begin training ###########################################################################################################################
    io.cprint("\nBeginning Training..\n")

    for epoch in range(last_epoch + 1, last_epoch + epochs + 1):

        io.cprint(
            "Epoch: %d ------------------------------------------------------------------------------------------"
            % (epoch))
        io.cprint("Current LR: %.10f" % (opt.param_groups[0]['lr']))

        model.train()
        opt.zero_grad()

        checkpoint['train_batch_loss'].append([])
        checkpoint['train_batch_N'].append([])
        checkpoint['train_batch_acc'].append([])
        checkpoint['train_batch_C0_acc'].append([])
        checkpoint['train_batch_C1_acc'].append([])
        checkpoint['train_batch_lr_adjust'].append([])
        checkpoint['train_batch_loss_reduction'].append([])
        checkpoint['lr'].append(opt.param_groups[0]['lr'])

        # draw random batches from random bins
        binbatches = utils.drawBinBatches([bin.size(0) for bin in train_bins],
                                          batchsize=batch_size)
        checkpoint['train_batch_N'][-1] = [
            train_bins[bin_id][batch_ids].size(1)
            for (bin_id, batch_ids) in binbatches
        ]

        failed_loss_optims = 0
        cum_lr_adjust_fac = 0
        cum_loss_reduction = 0

        # pre-compute random rotations if needed
        batch_rotations = [None] * len(binbatches)
        if rotations:
            start_rotations = time.time()
            batch_rotations = torch.zeros(
                (len(binbatches), batch_size, pts.size(1), pts.size(1)),
                device=device)
            for i in range(len(binbatches)):
                for j in range(batch_size):
                    batch_rotations[i, j] = utils.getRandomRotation(
                        pts.size(1), device=device)
            print("created batch rotations (%ds)" %
                  (time.time() - start_rotations))

        b = 0  # batch counter
        train_start = time.time()

        for (bin_id, batch_ids) in binbatches:

            b += 1

            batch_pts_ids = train_bins[bin_id][batch_ids][:, :, 0]  # size: (B x N)
            batch_input = pts[batch_pts_ids]  # size: (B x N x d)
            batch_target = train_bins[bin_id][batch_ids][:, :, 1].to(device)  # size: (B x N)

            if batch_rotations[b - 1] is not None:
                batch_input = batch_input.matmul(batch_rotations[b - 1])

            if noise > 0:
                noise_v = torch.randn(batch_input.size(),
                                      device=batch_input.device)  # size: (B x N x d)
                noise_v.div_(torch.square(noise_v).sum(
                    dim=2).sqrt()[:, :, None])  # norm to unit vectors
                # in-place update; the original called the out-of-place addcmul
                # and discarded the result
                batch_input.addcmul_(noise_v,
                                     max_noise[batch_pts_ids][:, :, None],
                                     value=noise)

            if scale_min < 1 or scale_max > 1:
                # batch_scales = scale_min + torch.rand(batch_input.size(0), device=batch_input.device) * (scale_max - scale_min)
                batch_scales = torch.rand(batch_input.size(0),
                                          device=batch_input.device)
                batch_scales.mul_(scale_max - scale_min)
                batch_scales.add_(scale_min)
                # in-place update; the original called the out-of-place mul
                # and discarded the result
                batch_input.mul_(batch_scales[:, None, None])

            batch_input = batch_input.transpose(1, 2)  # size: (B x d x N)

            # prediction & loss ----------------------------------------
            batch_prediction = model(batch_input).transpose(1, 2)  # size: (B x N x 2)
            batch_loss = cross_entropy(batch_prediction.reshape(-1, 2),
                                       batch_target.view(-1),
                                       class_weights,
                                       reduction='mean')
            batch_loss.backward()
            checkpoint['train_batch_loss'][-1].append(batch_loss.item())

            new_loss = 0.0
            lr_adjust = 1.0
            loss_reduction = 0.0

            # if safe descent is enabled, try to optimize the descent step so that a reduction in loss is guaranteed
            if safe_descent:

                # create backups to restore states before the optimizer step
                model_state_backup = copy.deepcopy(model.state_dict())
                opt_state_backup = copy.deepcopy(opt.state_dict())

                # make an optimizer step
                opt.step()

                # in each iteration, check if the optimizer gave an improvement;
                # if not, restore the original states, reduce the learning rate and try again
                # no gradient needed for the plain loss calculation
                with torch.no_grad():
                    for i in range(10):
                        # new_batch_prediction = model(batch_input).transpose(1,2).contiguous()
                        new_batch_prediction = model(batch_input).transpose(1, 2)
                        new_loss = cross_entropy(
                            new_batch_prediction.reshape(-1, 2),
                            batch_target.view(-1),
                            class_weights,
                            reduction='mean').item()
                        # if the model performs better now we continue, if not we try a smaller learning step
                        if (new_loss < batch_loss.item()):
                            # print("lucky! (%f -> %f) reduction: %.4f%%" % (batch_loss.item(), new_loss, 100 * (batch_loss.item()-new_loss) / batch_loss.item()))
                            break
                        else:
                            # print("try again.. (%f -> %f)" % (batch_loss.item(), new_loss))
                            model.load_state_dict(model_state_backup)
                            opt.load_state_dict(opt_state_backup)
                            lr_adjust *= 0.7
                            opt.step(lr_adjust=lr_adjust)

                loss_reduction = 100 * (batch_loss.item() - new_loss) / batch_loss.item()

                if new_loss >= batch_loss.item():
                    failed_loss_optims += 1
                else:
                    cum_lr_adjust_fac += lr_adjust
                    cum_loss_reduction += loss_reduction
            else:
                cum_lr_adjust_fac += lr_adjust
                opt.step()

            checkpoint['train_batch_lr_adjust'][-1].append(lr_adjust)
            checkpoint['train_batch_loss_reduction'][-1].append(loss_reduction)

            # reset gradients
            opt.zero_grad()

            # make class prediction and save stats -----------------------
            success_vector = torch.argmax(batch_prediction, dim=2) == batch_target
            c0_idx = batch_target == 0
            c1_idx = batch_target == 1
            c0_total = torch.sum(c0_idx).item()
            c1_total = torch.sum(c1_idx).item()
            checkpoint['train_batch_acc'][-1].append(
                torch.sum(success_vector).item() / success_vector.numel())
            # guard against batches that contain only one class (division by zero)
            checkpoint['train_batch_C0_acc'][-1].append(
                torch.sum(success_vector[c0_idx]).item() / c0_total
                if c0_total > 0 else 0.0)
            checkpoint['train_batch_C1_acc'][-1].append(
                torch.sum(success_vector[c1_idx]).item() / c1_total
                if c1_total > 0 else 0.0)

            # statistic calculation and output -------------------------
            if b % report_interval == 0:
                last_100_loss = sum(checkpoint['train_batch_loss'][-1]
                                    [b - report_interval:b]) / report_interval
                last_100_acc = sum(checkpoint['train_batch_acc'][-1]
                                   [b - report_interval:b]) / report_interval
                last_100_acc_c0 = sum(checkpoint['train_batch_C0_acc'][-1]
                                      [b - report_interval:b]) / report_interval
                last_100_acc_c1 = sum(checkpoint['train_batch_C1_acc'][-1]
                                      [b - report_interval:b]) / report_interval
                io.cprint(
                    ' Batch %4d to %4d | loss: %.5f%1s| acc: %.4f%1s| C0 acc: %.4f%1s| C1 acc: %.4f%1s| E%3d | T:%5ds | Failed Optims: %3d (%05.2f%%) | Av. Adjust LR: %.6f | Av. Loss Reduction: %07.4f%%'
                    % (b - (report_interval - 1), b, last_100_loss,
                       '+' if epoch > 1 and last_100_loss < checkpoint['train_loss'][-1] else '',
                       last_100_acc,
                       '+' if epoch > 1 and last_100_acc > checkpoint['train_acc'][-1] else '',
                       last_100_acc_c0,
                       '+' if epoch > 1 and last_100_acc_c0 > checkpoint['train_C0_acc'][-1] else '',
                       last_100_acc_c1,
                       '+' if epoch > 1 and last_100_acc_c1 > checkpoint['train_C1_acc'][-1] else '',
                       epoch, time.time() - train_start, failed_loss_optims,
                       100 * (failed_loss_optims / report_interval),
                       (cum_lr_adjust_fac / (report_interval - failed_loss_optims)
                        if failed_loss_optims < report_interval else -1),
                       (cum_loss_reduction / (report_interval - failed_loss_optims)
                        if failed_loss_optims < report_interval else -1)))
                failed_loss_optims = 0
                cum_lr_adjust_fac = 0
                cum_loss_reduction = 0

        checkpoint['train_loss'].append(sum(checkpoint['train_batch_loss'][-1]) / b)
        checkpoint['train_acc'].append(sum(checkpoint['train_batch_acc'][-1]) / b)
        checkpoint['train_C0_acc'].append(sum(checkpoint['train_batch_C0_acc'][-1]) / b)
        checkpoint['train_C1_acc'].append(sum(checkpoint['train_batch_C1_acc'][-1]) / b)
        checkpoint['train_time'].append(time.time() - train_start)

        io.cprint(
            '----\n TRN | time: %5ds | loss: %.10f | acc: %.4f | C0 acc: %.4f | C1 acc: %.4f'
            % (checkpoint['train_time'][-1], checkpoint['train_loss'][-1],
               checkpoint['train_acc'][-1], checkpoint['train_C0_acc'][-1],
               checkpoint['train_C1_acc'][-1]))

        torch.cuda.empty_cache()

        ####################
        # Test & Validation
        ####################

        with torch.no_grad():

            if use_batch_norm:
                model.eval_bn()
                eval_bn_start = time.time()
                # run through all train samples again to accumulate layer-wise input
                # distribution statistics (mean and variance) with fixed weights;
                # these statistics are later used for the BatchNorm layers during inference
                for (bin_id, batch_ids) in binbatches:
                    batch_pts_ids = train_bins[bin_id][batch_ids][:, :, 0]  # size: (B x N)
                    batch_input = pts[batch_pts_ids]  # size: (B x N x d)
                    # batch_input = batch_input.transpose(1,2).contiguous() # size: (B x d x N)
                    batch_input = batch_input.transpose(1, 2)  # size: (B x d x N)
                    model(batch_input)
                io.cprint('Accumulated BN Layer statistics (%ds)' %
                          (time.time() - eval_bn_start))

            model.eval()

            if len(test_samples) > 0:
                test_start = time.time()
                test_loss, test_acc, test_acc_c0, test_acc_c1 = getTestLoss(
                    pts, test_samples, model, class_weights)
                checkpoint['test_loss'].append(test_loss)
                checkpoint['test_acc'].append(test_acc)
                checkpoint['test_C0_acc'].append(test_acc_c0)
                checkpoint['test_C1_acc'].append(test_acc_c1)
                checkpoint['test_time'].append(time.time() - test_start)
                io.cprint(
                    ' TST | time: %5ds | loss: %.10f | acc: %.4f | C0 acc: %.4f | C1 acc: %.4f'
                    % (checkpoint['test_time'][-1], checkpoint['test_loss'][-1],
                       checkpoint['test_acc'][-1], checkpoint['test_C0_acc'][-1],
                       checkpoint['test_C1_acc'][-1]))
            else:
                io.cprint(' TST | n/a (no samples)')

            if len(val_samples) > 0:
                val_start = time.time()
                val_loss, val_acc, val_acc_c0, val_acc_c1 = getTestLoss(
                    val_pts, val_samples, model, class_weights)
                checkpoint['val_loss'].append(val_loss)
                checkpoint['val_acc'].append(val_acc)
                checkpoint['val_C0_acc'].append(val_acc_c0)
                checkpoint['val_C1_acc'].append(val_acc_c1)
                checkpoint['val_time'].append(time.time() - val_start)
                io.cprint(
                    ' VAL | time: %5ds | loss: %.10f | acc: %.4f | C0 acc: %.4f | C1 acc: %.4f'
                    % (checkpoint['val_time'][-1], checkpoint['val_loss'][-1],
                       checkpoint['val_acc'][-1], checkpoint['val_C0_acc'][-1],
                       checkpoint['val_C1_acc'][-1]))
            else:
                io.cprint(' VAL | n/a (no samples)')

        ####################
        # Scheduler Step
        ####################

        if not no_lr_schedule:
            scheduler.step()

        if epoch > 1 and dynamic_lr and sum(
                checkpoint['train_batch_lr_adjust'][-1]) > 0:
            io.cprint("----\n dynamic lr adjust: %.10f" %
                      (0.5 * (1 + sum(checkpoint['train_batch_lr_adjust'][-1]) /
                              len(checkpoint['train_batch_lr_adjust'][-1]))))
            for param_group in opt.param_groups:
                param_group['lr'] *= 0.5 * (
                    1 + sum(checkpoint['train_batch_lr_adjust'][-1]) /
                    len(checkpoint['train_batch_lr_adjust'][-1]))

        # Save model and optimizer state ..
        checkpoint['model_state_dict'].append(copy.deepcopy(model.state_dict()))
        checkpoint['optimizer_state_dict'].append(copy.deepcopy(opt.state_dict()))
        torch.save(checkpoint, exp_dir + '/detector_checkpoints.t7')

    io.cprint("\n-------------------------------------------------------" +
              ("\ntotal_time: %.2fh" % ((time.time() - start_time) / 3600)) +
              ("\ntrain_time: %.2fh" % (sum(checkpoint['train_time']) / 3600)) +
              ("\ntest_time: %.2fh" % (sum(checkpoint['test_time']) / 3600)) +
              ("\nval_time: %.2fh" % (sum(checkpoint['val_time']) / 3600)) +
              "\n-------------------------------------------------------" +
              "\nend_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")
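# The detector loss above calls a standalone `cross_entropy(input, target,
# weights, reduction=...)` helper with explicit class weights. A minimal sketch
# in terms of torch.nn.functional (assumed; the repository may implement its own
# weighted variant):
import torch.nn.functional as F

def cross_entropy_sketch(prediction, target, class_weights, reduction='mean'):
    # prediction: (B*N, 2) raw logits, target: (B*N,) integer class labels
    return F.cross_entropy(prediction, target.long(),
                           weight=class_weights, reduction=reduction)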
def train(args, log_dir, checkpoint_path, trainloader, testloader, tensorboard,
          c, model_name, ap, cuda=True, model_params=None):
    loss1_weight = c.train_config['loss1_weight']
    use_mixup = False if 'mixup' not in c.model else c.model['mixup']
    if use_mixup:
        mixup_alpha = 1 if 'mixup_alpha' not in c.model else c.model['mixup_alpha']
        mixup_augmenter = Mixup(mixup_alpha=mixup_alpha)
        print("Enable Mixup with alpha:", mixup_alpha)

    model = return_model(c, model_params)

    if c.train_config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=c.train_config['learning_rate'],
                                     weight_decay=c.train_config['weight_decay'])
    elif c.train_config['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=c.train_config['learning_rate'],
                                      weight_decay=c.train_config['weight_decay'])
    elif c.train_config['optimizer'] == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=c.train_config['learning_rate'],
                          weight_decay=c.train_config['weight_decay'])
    else:
        raise Exception("'%s' is not a supported optimizer" %
                        c.train_config['optimizer'])

    step = 0
    if checkpoint_path is not None:
        print("Continue training from checkpoint: %s" % checkpoint_path)
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
        except (KeyError, RuntimeError):
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
    else:
        print("Starting new training run")

    if c.train_config['lr_decay']:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.train_config['warmup_steps'],
                           last_epoch=step - 1)
    else:
        scheduler = None

    # move model to cuda
    if cuda:
        model = model.cuda()

    # define loss function
    if use_mixup:
        criterion = Clip_BCE()
    else:
        criterion = nn.BCELoss()
    eval_criterion = nn.BCELoss(reduction='sum')

    best_loss = float('inf')

    # early stop definitions
    early_epochs = 0

    model.train()
    for epoch in range(c.train_config['epochs']):
        for feature, target in trainloader:
            if cuda:
                feature = feature.cuda()
                target = target.cuda()

            if use_mixup:
                batch_len = len(feature)
                if (batch_len % 2) != 0:
                    batch_len -= 1
                feature = feature[:batch_len]
                target = target[:batch_len]
                mixup_lambda = torch.FloatTensor(
                    mixup_augmenter.get_lambda(batch_len)).to(feature.device)
                output = model(feature[:batch_len], mixup_lambda)
                target = do_mixup(target, mixup_lambda)
            else:
                output = model(feature)

            # Calculate loss
            if c.dataset['class_balancer_batch'] and not use_mixup:
                idxs = (target == c.dataset['control_class'])
                loss_control = criterion(output[idxs], target[idxs])
                idxs = (target == c.dataset['patient_class'])
                loss_patient = criterion(output[idxs], target[idxs])
                loss = (loss_control + loss_patient) / 2
            else:
                loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update lr decay scheme
            if scheduler:
                scheduler.step()
            step += 1

            loss = loss.item()
            if loss > 1e8 or math.isnan(loss):
                print("Loss exploded to %.02f at step %d!" % (loss, step))
                break

            # write loss to tensorboard
            if step % c.train_config['summary_interval'] == 0:
                tensorboard.log_training(loss, step)
                if c.dataset['class_balancer_batch'] and not use_mixup:
                    print("Write summary at step %d" % step, ' Loss: ', loss,
                          'Loss control:', loss_control.item(),
                          'Loss patient:', loss_patient.item())
                else:
                    print("Write summary at step %d" % step, ' Loss: ', loss)

            # save checkpoint file and evaluate and save sample to tensorboard
            if step % c.train_config['checkpoint_interval'] == 0:
                save_path = os.path.join(log_dir, 'checkpoint_%d.pt' % step)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': step,
                        'config_str': str(c),
                    }, save_path)
                print("Saved checkpoint to: %s" % save_path)
                # run validation and save best checkpoint
                val_loss = validation(eval_criterion, ap, model, c, testloader,
                                      tensorboard, step, cuda=cuda,
                                      loss1_weight=loss1_weight)
                best_loss, _ = save_best_checkpoint(
                    log_dir, model, optimizer, c, step, val_loss, best_loss,
                    early_epochs if c.train_config['early_stop_epochs'] != 0 else None)

        print('=================================================')
        print("Epoch %d End !" % epoch)
        print('=================================================')

        # run validation and save best checkpoint at end epoch
        val_loss = validation(eval_criterion, ap, model, c, testloader,
                              tensorboard, step, cuda=cuda,
                              loss1_weight=loss1_weight)
        best_loss, early_epochs = save_best_checkpoint(
            log_dir, model, optimizer, c, step, val_loss, best_loss,
            early_epochs if c.train_config['early_stop_epochs'] != 0 else None)

        if c.train_config['early_stop_epochs'] != 0:
            if early_epochs is not None:
                if early_epochs >= c.train_config['early_stop_epochs']:
                    break  # stop train
    return best_loss
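# `NoamLR` above applies the usual Transformer warmup-then-decay rule. A minimal
# sketch of that rule (assumed behavior; the imported class may differ in detail):
class NoamLRSketch(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # linear warmup for `warmup_steps` steps, then inverse-sqrt decay
        step = max(1, self.last_epoch)
        scale = self.warmup_steps**0.5 * min(step**-0.5,
                                             step * self.warmup_steps**-1.5)
        return [base_lr * scale for base_lr in self.base_lrs]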