def train_model(
        idx_np, name: str, model_class: Type[nn.Module], graph: SparseGraph,
        model_args: dict, learning_rate: float, reg_lambda: float,
        stopping_args: dict = stopping_args, test: bool = True,
        device: str = 'cuda', torch_seed: int = None,
        print_interval: int = 10) -> Tuple[nn.Module, dict]:

    labels_all = graph.labels
    idx_all = {key: torch.LongTensor(val) for key, val in idx_np.items()}

    logging.log(21, f"{model_class.__name__}: {model_args}")
    if torch_seed is None:
        torch_seed = gen_seeds()
    torch.manual_seed(seed=torch_seed)
    logging.log(22, f"PyTorch seed: {torch_seed}")

    nfeatures = graph.attr_matrix.shape[1]
    nclasses = max(labels_all) + 1
    model = model_class(nfeatures, nclasses, **model_args).to(device)

    reg_lambda = torch.tensor(reg_lambda, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    dataloaders = get_dataloaders(idx_all, labels_all)
    early_stopping = EarlyStopping(model, **stopping_args)
    attr_mat_norm_np = normalize_attributes(graph.attr_matrix)
    attr_mat_norm = matrix_to_torch(attr_mat_norm_np).to(device)

    epoch_stats = {'train': {}, 'stopping': {}}

    start_time = time.time()
    last_time = start_time
    for epoch in range(early_stopping.max_epochs):
        for phase in epoch_stats.keys():

            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0
            running_corrects = 0

            for idx, labels in dataloaders[phase]:
                idx = idx.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    log_preds = model(attr_mat_norm, idx)
                    preds = torch.argmax(log_preds, dim=1)

                    # Calculate loss
                    cross_entropy_mean = F.nll_loss(log_preds, labels)
                    l2_reg = sum(torch.sum(param ** 2) for param in model.reg_params)
                    loss = cross_entropy_mean + reg_lambda / 2 * l2_reg

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Collect statistics
                running_loss += loss.item() * idx.size(0)
                running_corrects += torch.sum(preds == labels)

            # Collect statistics
            epoch_stats[phase]['loss'] = running_loss / len(dataloaders[phase].dataset)
            epoch_stats[phase]['acc'] = running_corrects.item() / len(dataloaders[phase].dataset)

        if epoch % print_interval == 0:
            duration = time.time() - last_time
            last_time = time.time()
            print(f"Epoch {epoch}: "
                  f"Train loss = {epoch_stats['train']['loss']:.2f}, "
                  f"train acc = {epoch_stats['train']['acc'] * 100:.1f}, "
                  f"early stopping loss = {epoch_stats['stopping']['loss']:.2f}, "
                  f"early stopping acc = {epoch_stats['stopping']['acc'] * 100:.1f} "
                  f"({duration:.3f} sec)")
            logging.info(f"Epoch {epoch}: "
                         f"Train loss = {epoch_stats['train']['loss']:.2f}, "
                         f"train acc = {epoch_stats['train']['acc'] * 100:.1f}, "
                         f"early stopping loss = {epoch_stats['stopping']['loss']:.2f}, "
                         f"early stopping acc = {epoch_stats['stopping']['acc'] * 100:.1f} "
                         f"({duration:.3f} sec)")

        if len(early_stopping.stop_vars) > 0:
            stop_vars = [epoch_stats['stopping'][key] for key in early_stopping.stop_vars]
            if early_stopping.check(stop_vars, epoch):
                break

    runtime = time.time() - start_time
    runtime_perepoch = runtime / (epoch + 1)
    logging.log(22, f"Last epoch: {epoch}, best epoch: {early_stopping.best_epoch} ({runtime:.3f} sec)")

    # Load best model weights
    model.load_state_dict(early_stopping.best_state, False)

    train_preds = get_predictions(model, attr_mat_norm, idx_all['train'])
    train_acc = (train_preds == labels_all[idx_all['train']]).mean()

    stopping_preds = get_predictions(model, attr_mat_norm, idx_all['stopping'])
    stopping_acc = (stopping_preds == labels_all[idx_all['stopping']]).mean()
    logging.log(21, f"Early stopping accuracy: {stopping_acc * 100:.1f}%")

    valtest_preds = get_predictions(model, attr_mat_norm, idx_all['valtest'])
    valtest_acc = (valtest_preds == labels_all[idx_all['valtest']]).mean()
    valtest_name = 'Test' if test else 'Validation'
    logging.log(22, f"{valtest_name} accuracy: {valtest_acc * 100:.1f}%")

    result = {}
    result['predictions'] = get_predictions(model, attr_mat_norm, torch.arange(len(labels_all)))
    result['train'] = {'accuracy': train_acc}
    result['early_stopping'] = {'accuracy': stopping_acc}
    result['valtest'] = {'accuracy': valtest_acc}
    result['runtime'] = runtime
    result['runtime_perepoch'] = runtime_perepoch

    return model, result
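# Illustrative call of train_model (a minimal sketch, not part of this module):
# `load_graph`, the index arrays, the model class, and the hyperparameters below
# are assumed placeholders, not names defined in this file.
#
#   graph = load_graph('cora_ml')  # hypothetical SparseGraph loader
#   idx_np = {'train': train_idx, 'stopping': stop_idx, 'valtest': valtest_idx}
#   model, result = train_model(
#       idx_np, name='PPNP', model_class=PPNP, graph=graph,
#       model_args={'hiddenunits': [64], 'drop_prob': 0.5},
#       learning_rate=0.01, reg_lambda=5e-3)
#   print(result['valtest']['accuracy'])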
def main(args):
    # fix random seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    criterion = nn.CrossEntropyLoss()
    cluster_log = Logger(os.path.join(args.exp, 'clusters.pickle'))

    # CNN
    if args.verbose:
        print('Architecture: {}'.format(args.arch))

    '''
    ##########################################
    ##########################################
    # Model definition
    ##########################################
    ##########################################'''
    model = models.__dict__[args.arch](bn=True, num_cluster=args.nmb_cluster, num_category=args.nmb_category)
    fd = int(model.cluster_layer[0].weight.size()[1])  # due to transpose, fd is input dim of W (in dim, out dim)
    model.cluster_layer = None
    model.category_layer = None

    model.features = torch.nn.DataParallel(model.features)
    model = model.double()
    model.to(device)
    cudnn.benchmark = True

    if args.optimizer == 'Adam':
        print('Adam optimizer: conv')
        optimizer_body = torch.optim.Adam(
            filter(lambda x: x.requires_grad, model.parameters()),
            lr=args.lr_Adam,
            betas=(0.9, 0.999),
            weight_decay=10 ** args.wd,
        )
    else:
        print('SGD optimizer: conv')
        optimizer_body = torch.optim.SGD(
            filter(lambda x: x.requires_grad, model.parameters()),
            lr=args.lr_SGD,
            momentum=args.momentum,
            weight_decay=10 ** args.wd,
        )

    '''
    ###############
    ###############
    category_layer
    ###############
    ###############
    '''
    model.category_layer = nn.Sequential(
        nn.Linear(fd, args.nmb_category),
        nn.Softmax(dim=1),
    )
    model.category_layer[0].weight.data.normal_(0, 0.01)
    model.category_layer[0].bias.data.zero_()
    model.category_layer = model.category_layer.double()
    model.category_layer.to(device)

    '''
    ############################
    ############################
    # EarlyStopping (test_accuracy_bal, 100)
    ############################
    ############################
    '''
    early_stopping = EarlyStopping(model, **stopping_args)
    stop_vars = []

    if args.optimizer == 'Adam':
        print('Adam optimizer: conv')
        optimizer_category = torch.optim.Adam(
            filter(lambda x: x.requires_grad, model.category_layer.parameters()),
            lr=args.lr_Adam,
            betas=(0.9, 0.999),
            weight_decay=10 ** args.wd,
        )
    else:
        print('SGD optimizer: conv')
        optimizer_category = torch.optim.SGD(
            filter(lambda x: x.requires_grad, model.category_layer.parameters()),
            lr=args.lr_SGD,
            momentum=args.momentum,
            weight_decay=10 ** args.wd,
        )

    '''
    ########################################
    ########################################
    Create echogram sampling index
    ########################################
    ########################################'''
    print('Sample echograms.')
    dataset_cp, dataset_semi = sampling_echograms_full(args)
    dataloader_cp = torch.utils.data.DataLoader(dataset_cp,
                                                shuffle=False,
                                                batch_size=args.batch,
                                                num_workers=args.workers,
                                                drop_last=False,
                                                pin_memory=True)
    dataloader_semi = torch.utils.data.DataLoader(dataset_semi,
                                                  shuffle=False,
                                                  batch_size=args.batch,
                                                  num_workers=args.workers,
                                                  drop_last=False,
                                                  pin_memory=True)

    dataset_test_bal, dataset_test_unbal = sampling_echograms_test(args)
    dataloader_test_bal = torch.utils.data.DataLoader(dataset_test_bal,
                                                      shuffle=False,
                                                      batch_size=args.batch,
                                                      num_workers=args.workers,
                                                      drop_last=False,
                                                      pin_memory=True)
    dataloader_test_unbal = torch.utils.data.DataLoader(dataset_test_unbal,
                                                        shuffle=False,
                                                        batch_size=args.batch,
                                                        num_workers=args.workers,
                                                        drop_last=False,
                                                        pin_memory=True)

    # clustering algorithm to use
    deepcluster = clustering.__dict__[args.clustering](args.nmb_cluster, args.pca)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # remove top located layer parameters from checkpoint
            copy_checkpoint_state_dict = checkpoint['state_dict'].copy()
            for key in list(copy_checkpoint_state_dict):
                if 'cluster_layer' in key:
                    del copy_checkpoint_state_dict[key]
                # if 'category_layer' in key:
                #     del copy_checkpoint_state_dict[key]
            checkpoint['state_dict'] = copy_checkpoint_state_dict
            model.load_state_dict(checkpoint['state_dict'])
            optimizer_body.load_state_dict(checkpoint['optimizer_body'])
            optimizer_category.load_state_dict(checkpoint['optimizer_category'])
            category_save = os.path.join(args.exp, 'category_layer.pth.tar')
            if os.path.isfile(category_save):
                category_layer_param = torch.load(category_save)
                model.category_layer.load_state_dict(category_layer_param)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # creating checkpoint repo
    exp_check = os.path.join(args.exp, 'checkpoints')
    if not os.path.isdir(exp_check):
        os.makedirs(exp_check)

    exp_bal = os.path.join(args.exp, 'bal')
    exp_unbal = os.path.join(args.exp, 'unbal')
    for dir_bal in [exp_bal, exp_unbal]:
        for dir_2 in ['features', 'pca_features', 'pred']:
            dir_to_make = os.path.join(dir_bal, dir_2)
            if not os.path.isdir(dir_to_make):
                os.makedirs(dir_to_make)

    if os.path.isfile(os.path.join(args.exp, 'loss_collect.pickle')):
        with open(os.path.join(args.exp, 'loss_collect.pickle'), "rb") as f:
            loss_collect = pickle.load(f)
    else:
        loss_collect = [[], [], [], [], [], [], [], [], []]

    if os.path.isfile(os.path.join(args.exp, 'nmi_collect.pickle')):
        with open(os.path.join(args.exp, 'nmi_collect.pickle'), "rb") as ff:
            nmi_save = pickle.load(ff)
    else:
        nmi_save = []

    '''
    #######################
    #######################
    MAIN TRAINING
    #######################
    #######################'''
    for epoch in range(args.start_epoch, early_stopping.max_epochs):
        end = time.time()
        print('##################### Start training at Epoch %d ################' % epoch)
        model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])  # remove ReLU at classifier [:-1]
        model.cluster_layer = None
        model.category_layer = None

        '''
        #######################
        #######################
        PSEUDO-LABEL GENERATION
        #######################
        #######################
        '''
        print('Cluster the features')
        features_train, input_tensors_train, labels_train = compute_features(
            dataloader_cp, model, len(dataset_cp), device=device, args=args)

        clustering_loss, pca_features = deepcluster.cluster(features_train, verbose=args.verbose)

        nan_location = np.isnan(pca_features)
        inf_location = np.isinf(pca_features)
        if (not np.allclose(nan_location, 0)) or (not np.allclose(inf_location, 0)):
            print('PCA: Feature NaN or Inf found. NaN count: ', np.sum(nan_location),
                  ' Inf count: ', np.sum(inf_location))
            print('Skip epoch ', epoch)
            torch.save(pca_features, 'tr_pca_NaN_%d.pth.tar' % epoch)
            torch.save(features_train, 'tr_feature_NaN_%d.pth.tar' % epoch)
            continue

        print('Assign pseudo labels')
        size_cluster = np.zeros(len(deepcluster.images_lists))
        for i, _list in enumerate(deepcluster.images_lists):
            size_cluster[i] = len(_list)
        print('size in clusters: ', size_cluster)
        img_label_pair_train = zip_img_label(input_tensors_train, labels_train)
        train_dataset = clustering.cluster_assign(deepcluster.images_lists,
                                                  img_label_pair_train)  # Reassigned pseudolabel

        # uniformly sample per target
        sampler_train = UnifLabelSampler(int(len(train_dataset)), deepcluster.images_lists)

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch,
            shuffle=False,
            num_workers=args.workers,
            sampler=sampler_train,
            pin_memory=True,
        )

        '''
        ####################################################################
        ####################################################################
        TRANSFORM MODEL FOR SELF-SUPERVISION // SEMI-SUPERVISION
        ####################################################################
        ####################################################################
        '''
        # Recover classifier with ReLU (that is not used in clustering)
        mlp = list(model.classifier.children())  # classifier that ends with linear(512 * 128). No ReLU at the end
        mlp.append(nn.ReLU(inplace=True).to(device))
        model.classifier = nn.Sequential(*mlp)
        model.classifier.to(device)

        '''SELF-SUPERVISION (PSEUDO-LABELS)'''
        model.category_layer = None
        model.cluster_layer = nn.Sequential(
            nn.Linear(fd, args.nmb_cluster),  # nn.Linear(4096, num_cluster),
            nn.Softmax(dim=1),  # should be removed and replaced by ReLU for category_layer
        )
        model.cluster_layer[0].weight.data.normal_(0, 0.01)
        model.cluster_layer[0].bias.data.zero_()
        model.cluster_layer = model.cluster_layer.double()
        model.cluster_layer.to(device)

        ''' train network with clusters as pseudo-labels '''
        with torch.autograd.set_detect_anomaly(True):
            pseudo_loss, semi_loss, semi_accuracy = semi_train(
                train_dataloader, dataloader_semi, model, fd, criterion,
                optimizer_body, optimizer_category, epoch, device=device, args=args)

        # save checkpoint
        if (epoch + 1) % args.checkpoints == 0:
            path = os.path.join(
                args.exp,
                'checkpoints',
                'checkpoint_' + str(epoch) + '.pth.tar',
            )
            if args.verbose:
                print('Save checkpoint at: {0}'.format(path))
            torch.save({'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer_body': optimizer_body.state_dict(),
                        'optimizer_category': optimizer_category.state_dict(),
                        }, path)

        '''
        ##############
        ##############
        # TEST phase
        ##############
        ##############
        '''
        test_loss_bal, test_accuracy_bal, test_pred_bal, test_label_bal = test(
            dataloader_test_bal, model, criterion, device, args)
        test_loss_unbal, test_accuracy_unbal, test_pred_unbal, test_label_unbal = test(
            dataloader_test_unbal, model, criterion, device, args)

        '''Save prediction of the test set'''
        if (epoch % args.save_epoch == 0):
            with open(os.path.join(args.exp, 'bal', 'pred', 'sup_epoch_%d_te_bal.pickle' % epoch), "wb") as f:
                pickle.dump([test_pred_bal, test_label_bal], f)
            with open(os.path.join(args.exp, 'unbal', 'pred', 'sup_epoch_%d_te_unbal.pickle' % epoch), "wb") as f:
                pickle.dump([test_pred_unbal, test_label_unbal], f)

        if args.verbose:
            print('###### Epoch [{0}] ###### \n'
                  'Time: {1:.3f} s\n'
                  'Pseudo tr_loss: {2:.3f} \n'
                  'SEMI tr_loss: {3:.3f} \n'
                  'TEST_bal loss: {4:.3f} \n'
                  'TEST_unbal loss: {5:.3f} \n'
                  'Clustering loss: {6:.3f} \n\n'
                  'SEMI accu: {7:.3f} \n'
                  'TEST_bal accu: {8:.3f} \n'
                  'TEST_unbal accu: {9:.3f} \n'
                  .format(epoch, time.time() - end, pseudo_loss, semi_loss,
                          test_loss_bal, test_loss_unbal, clustering_loss,
                          semi_accuracy, test_accuracy_bal, test_accuracy_unbal))
            try:
                nmi = normalized_mutual_info_score(
                    clustering.arrange_clustering(deepcluster.images_lists),
                    clustering.arrange_clustering(cluster_log.data[-1])
                )
                nmi_save.append(nmi)
                print('NMI against previous assignment: {0:.3f}'.format(nmi))
                with open(os.path.join(args.exp, 'nmi_collect.pickle'), "wb") as ff:
                    pickle.dump(nmi_save, ff)
            except IndexError:
                pass
            print('####################### \n')

        # save cluster assignments
        cluster_log.log(deepcluster.images_lists)

        # save running checkpoint
        torch.save({'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer_body': optimizer_body.state_dict(),
                    'optimizer_category': optimizer_category.state_dict(),
                    }, os.path.join(args.exp, 'checkpoint.pth.tar'))
        torch.save(model.category_layer.state_dict(), os.path.join(args.exp, 'category_layer.pth.tar'))

        loss_collect[0].append(epoch)
        loss_collect[1].append(pseudo_loss)
        loss_collect[2].append(semi_loss)
        loss_collect[3].append(clustering_loss)
        loss_collect[4].append(test_loss_bal)
        loss_collect[5].append(test_loss_unbal)
        loss_collect[6].append(semi_accuracy)
        loss_collect[7].append(test_accuracy_bal)
        loss_collect[8].append(test_accuracy_unbal)
        with open(os.path.join(args.exp, 'loss_collect.pickle'), "wb") as f:
            pickle.dump(loss_collect, f)

        if (epoch % args.save_epoch == 0):
            out = produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster)
            out = produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster)

        '''EarlyStopping'''
        if early_stopping.check(loss_collect[7], epoch):
            break

    out = produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster)
    out = produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster)

    '''