Ejemplo n.º 1
0
def train_model(
        idx_np, name: str, model_class: Type[nn.Module], graph: SparseGraph, model_args: dict,
        learning_rate: float, reg_lambda: float,
        stopping_args: dict = stopping_args,
        test: bool = True, device: str = 'cuda',
        torch_seed: int = None, print_interval: int = 10) -> Tuple[nn.Module, dict]:
    """Train ``model_class`` on ``graph`` with early stopping and report accuracies.

    Args:
        idx_np: Mapping from split name to node indices. The keys 'train',
            'stopping' and 'valtest' are read below.
        name: Not used inside this function.
        model_class: Module class, instantiated as
            ``model_class(nfeatures, nclasses, **model_args)``. The instance
            must expose ``reg_params`` (parameters included in the L2 penalty).
        graph: Supplies ``labels`` and ``attr_matrix``.
        model_args: Extra keyword arguments for ``model_class``.
        learning_rate: Learning rate passed to Adam.
        reg_lambda: Weight of the L2 penalty (applied as ``reg_lambda / 2``).
        stopping_args: Keyword arguments forwarded to ``EarlyStopping``.
        test: Only selects the label ('Test' or 'Validation') used when logging
            the 'valtest' accuracy.
        device: Device the model and tensors are moved to.
        torch_seed: Seed for ``torch.manual_seed``; drawn from ``gen_seeds()``
            when ``None``.
        print_interval: Progress is printed and logged every this many epochs.

    Returns:
        ``(model, result)`` where ``result`` holds the predictions for all
        nodes, per-split accuracies, the total runtime and the runtime per
        epoch.
    """
    labels_all = graph.labels
    idx_all = {}
    for split, indices in idx_np.items():
        idx_all[split] = torch.LongTensor(indices)

    logging.log(21, f"{model_class.__name__}: {model_args}")
    if torch_seed is None:
        torch_seed = gen_seeds()
    torch.manual_seed(seed=torch_seed)
    logging.log(22, f"PyTorch seed: {torch_seed}")

    nfeatures = graph.attr_matrix.shape[1]
    nclasses = max(labels_all) + 1
    model = model_class(nfeatures, nclasses, **model_args).to(device)

    # Keep the penalty weight on the same device as the loss terms.
    reg_coeff = torch.tensor(reg_lambda, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    dataloaders = get_dataloaders(idx_all, labels_all)
    early_stopping = EarlyStopping(model, **stopping_args)
    attr_mat_norm = matrix_to_torch(normalize_attributes(graph.attr_matrix)).to(device)

    # One stats dict per phase; the phases run in this order every epoch.
    epoch_stats = {'train': {}, 'stopping': {}}

    start_time = time.time()
    last_time = start_time
    for epoch in range(early_stopping.max_epochs):
        for phase, phase_stats in epoch_stats.items():
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            loader = dataloaders[phase]
            running_loss = 0
            running_corrects = 0

            for batch_idx, batch_labels in loader:
                batch_idx = batch_idx.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_training):
                    log_preds = model(attr_mat_norm, batch_idx)
                    preds = torch.argmax(log_preds, dim=1)

                    # NLL loss plus the L2 penalty over model.reg_params.
                    cross_entropy_mean = F.nll_loss(log_preds, batch_labels)
                    l2_reg = sum(torch.sum(param ** 2) for param in model.reg_params)
                    loss = cross_entropy_mean + reg_coeff / 2 * l2_reg

                    if is_training:
                        loss.backward()
                        optimizer.step()

                    # Weight the batch loss by batch size so the epoch value
                    # is a per-sample average.
                    running_loss += loss.item() * batch_idx.size(0)
                    running_corrects += torch.sum(preds == batch_labels)

            n_samples = len(loader.dataset)
            phase_stats['loss'] = running_loss / n_samples
            phase_stats['acc'] = running_corrects.item() / n_samples

        if epoch % print_interval == 0:
            duration = time.time() - last_time
            last_time = time.time()
            train_stats = epoch_stats['train']
            stopping_stats = epoch_stats['stopping']
            # Shared tail of the console and log messages.
            summary = (f"Train loss = {train_stats['loss']:.2f}, "
                       f"train acc = {train_stats['acc'] * 100:.1f}, "
                       f"early stopping loss = {stopping_stats['loss']:.2f}, "
                       f"early stopping acc = {stopping_stats['acc'] * 100:.1f} "
                       f"({duration:.3f} sec)")
            print(f"Epoch{epoch}:  " + summary)
            logging.info(f"Epoch {epoch}: " + summary)

        if len(early_stopping.stop_vars) > 0:
            current_vals = [epoch_stats['stopping'][key]
                            for key in early_stopping.stop_vars]
            if early_stopping.check(current_vals, epoch):
                break

    runtime = time.time() - start_time
    runtime_perepoch = runtime / (epoch + 1)
    logging.log(22, f"Last epoch: {epoch}, best epoch: {early_stopping.best_epoch} ({runtime:.3f} sec)")

    # Restore the weights recorded by early stopping (non-strict load).
    model.load_state_dict(early_stopping.best_state, False)

    def split_accuracy(split):
        """Return the mean agreement between predictions and labels on ``split``."""
        split_preds = get_predictions(model, attr_mat_norm, idx_all[split])
        return (split_preds == labels_all[idx_all[split]]).mean()

    train_acc = split_accuracy('train')

    stopping_acc = split_accuracy('stopping')
    logging.log(21, f"Early stopping accuracy: {stopping_acc * 100:.1f}%")

    valtest_acc = split_accuracy('valtest')
    valtest_name = 'Test' if test else 'Validation'
    logging.log(22, f"{valtest_name} accuracy: {valtest_acc * 100:.1f}%")

    all_preds = get_predictions(model, attr_mat_norm, torch.arange(len(labels_all)))
    result = {
        'predictions': all_preds,
        'train': {'accuracy': train_acc},
        'early_stopping': {'accuracy': stopping_acc},
        'valtest': {'accuracy': valtest_acc},
        'runtime': runtime,
        'runtime_perepoch': runtime_perepoch,
    }

    return model, result
Ejemplo n.º 2
0
def _make_optimizer(parameters, args):
    """Build the optimizer selected by ``args.optimizer`` over ``parameters``.

    Only parameters with ``requires_grad`` set are optimized. Adam is used
    when ``args.optimizer == 'Adam'``; any other value selects SGD.
    """
    trainable = filter(lambda x: x.requires_grad, parameters)
    # Compare by value. ``is`` tests object identity, which is not guaranteed
    # for equal strings (e.g. one parsed at runtime) and would silently fall
    # through to the SGD branch.
    if args.optimizer == 'Adam':
        print('Adam optimizer: conv')
        return torch.optim.Adam(
            trainable,
            lr=args.lr_Adam,
            betas=(0.9, 0.999),
            weight_decay=10 ** args.wd,
        )
    print('SGD optimizer: conv')
    return torch.optim.SGD(
        trainable,
        lr=args.lr_SGD,
        momentum=args.momentum,
        weight_decay=10 ** args.wd,
    )


def _make_sequential_loader(dataset, args):
    """Wrap ``dataset`` in a non-shuffling DataLoader that keeps every sample."""
    return torch.utils.data.DataLoader(dataset,
                                       shuffle=False,
                                       batch_size=args.batch,
                                       num_workers=args.workers,
                                       drop_last=False,
                                       pin_memory=True)


def _restore_classifier_relu(model, device):
    """Re-append the trailing ReLU that is stripped from ``model.classifier`` for clustering."""
    mlp = list(model.classifier.children())
    mlp.append(nn.ReLU(inplace=True).to(device))
    model.classifier = nn.Sequential(*mlp)
    model.classifier.to(device)


def main(args):
    """Run the clustering / pseudo-label training loop configured by ``args``.

    Steps, as visible in this function:
      1. Seed torch and numpy, build the model from ``models.__dict__[args.arch]``
         and create the body and category-layer optimizers.
      2. Build data loaders for the clustering, semi-supervised and two test
         datasets, and optionally resume from ``args.resume``.
      3. For each epoch: compute features, cluster them, assign pseudo-labels,
         call ``semi_train``, evaluate on both test sets, save checkpoints and
         statistics under ``args.exp``, and stop when ``early_stopping.check``
         returns true.
      4. Produce the final test results for the last epoch.

    Args:
        args: Namespace-like configuration object. Fields read here include
            ``seed``, ``exp``, ``arch``, ``nmb_cluster``, ``nmb_category``,
            ``optimizer``, ``lr_Adam``, ``lr_SGD``, ``momentum``, ``wd``,
            ``batch``, ``workers``, ``clustering``, ``pca``, ``resume``,
            ``start_epoch``, ``checkpoints``, ``save_epoch`` and ``verbose``.
    """
    # fix random seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
    print(device)
    criterion = nn.CrossEntropyLoss()
    cluster_log = Logger(os.path.join(args.exp, 'clusters.pickle'))

    # CNN
    if args.verbose:
        print('Architecture: {}'.format(args.arch))

    # ---------------- Model definition ----------------
    model = models.__dict__[args.arch](bn=True, num_cluster=args.nmb_cluster, num_category=args.nmb_category)
    fd = int(model.cluster_layer[0].weight.size()[1])  # due to transpose, fd is input dim of W (in dim, out dim)
    model.cluster_layer = None
    model.category_layer = None
    model.features = torch.nn.DataParallel(model.features)
    model = model.double()
    model.to(device)
    cudnn.benchmark = True

    # Created while both top layers are None, so it only covers the body.
    optimizer_body = _make_optimizer(model.parameters(), args)

    # ---------------- category_layer ----------------
    model.category_layer = nn.Sequential(
        nn.Linear(fd, args.nmb_category),
        nn.Softmax(dim=1),
    )
    model.category_layer[0].weight.data.normal_(0, 0.01)
    model.category_layer[0].bias.data.zero_()
    model.category_layer = model.category_layer.double()
    model.category_layer.to(device)

    # ---------------- EarlyStopping ----------------
    early_stopping = EarlyStopping(model, **stopping_args)

    optimizer_category = _make_optimizer(model.category_layer.parameters(), args)

    # ---------------- Create echogram sampling index ----------------
    print('Sample echograms.')
    dataset_cp, dataset_semi = sampling_echograms_full(args)
    dataloader_cp = _make_sequential_loader(dataset_cp, args)
    dataloader_semi = _make_sequential_loader(dataset_semi, args)

    dataset_test_bal, dataset_test_unbal = sampling_echograms_test(args)
    dataloader_test_bal = _make_sequential_loader(dataset_test_bal, args)
    dataloader_test_unbal = _make_sequential_loader(dataset_test_unbal, args)

    # clustering algorithm to use
    deepcluster = clustering.__dict__[args.clustering](args.nmb_cluster, args.pca)

    # optionally resume from a checkpoint
    # NOTE: torch.load / pickle.load below unpickle files from args.resume and
    # args.exp; only point them at trusted files.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # remove top located layer parameters from checkpoint
            # (category_layer entries are intentionally kept)
            copy_checkpoint_state_dict = checkpoint['state_dict'].copy()
            for key in list(copy_checkpoint_state_dict):
                if 'cluster_layer' in key:
                    del copy_checkpoint_state_dict[key]
            checkpoint['state_dict'] = copy_checkpoint_state_dict
            model.load_state_dict(checkpoint['state_dict'])
            optimizer_body.load_state_dict(checkpoint['optimizer_body'])
            optimizer_category.load_state_dict(checkpoint['optimizer_category'])
            category_save = os.path.join(args.exp,  'category_layer.pth.tar')
            if os.path.isfile(category_save):
                category_layer_param = torch.load(category_save)
                model.category_layer.load_state_dict(category_layer_param)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # creating checkpoint repo
    os.makedirs(os.path.join(args.exp, 'checkpoints'), exist_ok=True)

    exp_bal = os.path.join(args.exp, 'bal')
    exp_unbal = os.path.join(args.exp, 'unbal')
    for dir_bal in [exp_bal, exp_unbal]:
        for dir_2 in ['features', 'pca_features', 'pred']:
            os.makedirs(os.path.join(dir_bal, dir_2), exist_ok=True)

    # Continue earlier statistics when they exist in the experiment directory.
    if os.path.isfile(os.path.join(args.exp, 'loss_collect.pickle')):
        with open(os.path.join(args.exp, 'loss_collect.pickle'), "rb") as f:
            loss_collect = pickle.load(f)
    else:
        # Slots: epoch, pseudo loss, semi loss, clustering loss, bal/unbal test
        # loss, semi accuracy, bal/unbal test accuracy.
        loss_collect = [[], [], [], [], [], [], [], [], []]

    if os.path.isfile(os.path.join(args.exp, 'nmi_collect.pickle')):
        with open(os.path.join(args.exp, 'nmi_collect.pickle'), "rb") as ff:
            nmi_save = pickle.load(ff)
    else:
        nmi_save = []

    # ---------------- MAIN TRAINING ----------------
    # NOTE(review): if args.start_epoch >= early_stopping.max_epochs the loop
    # body never runs and `epoch` is unbound at the final test-result calls.
    for epoch in range(args.start_epoch, early_stopping.max_epochs):
        end = time.time()
        print('#####################  Start training at Epoch %d ################'% epoch)
        model.classifier = nn.Sequential(*list(model.classifier.children())[:-1]) # remove ReLU at classifier [:-1]
        model.cluster_layer = None
        model.category_layer = None

        # ---------------- PSEUDO-LABEL GENERATION ----------------
        print('Cluster the features')
        features_train, input_tensors_train, labels_train = compute_features(dataloader_cp, model, len(dataset_cp), device=device, args=args)
        clustering_loss, pca_features = deepcluster.cluster(features_train, verbose=args.verbose)

        nan_location = np.isnan(pca_features)
        inf_location = np.isinf(pca_features)
        if nan_location.any() or inf_location.any():
            print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location), ' Inf count: ', np.sum(inf_location))
            print('Skip epoch ', epoch)
            torch.save(pca_features, 'tr_pca_NaN_%d.pth.tar' % epoch)
            torch.save(features_train, 'tr_feature_NaN_%d.pth.tar' % epoch)
            # Put the ReLU back before skipping; otherwise the next epoch
            # would strip a real layer from the classifier instead.
            _restore_classifier_relu(model, device)
            continue

        print('Assign pseudo labels')
        size_cluster = np.array([len(_list) for _list in deepcluster.images_lists], dtype=float)
        print('size in clusters: ', size_cluster)
        img_label_pair_train = zip_img_label(input_tensors_train, labels_train)
        train_dataset = clustering.cluster_assign(deepcluster.images_lists,
                                                  img_label_pair_train)  # Reassigned pseudolabel

        # uniformly sample per target
        sampler_train = UnifLabelSampler(int(len(train_dataset)),
                                   deepcluster.images_lists)

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch,
            shuffle=False,
            num_workers=args.workers,
            sampler=sampler_train,
            pin_memory=True,
        )

        # -------- TRANSFORM MODEL FOR SELF-SUPERVISION // SEMI-SUPERVISION --------
        # Recover classifier with ReLU (that is not used in clustering)
        _restore_classifier_relu(model, device)

        # SELF-SUPERVISION (PSEUDO-LABELS)
        # NOTE(review): category_layer is None from here on and is not rebuilt
        # in this function, yet its state_dict is saved further down.
        # Presumably semi_train re-creates it; confirm.
        model.category_layer = None
        model.cluster_layer = nn.Sequential(
            nn.Linear(fd, args.nmb_cluster),  # nn.Linear(4096, num_cluster),
            nn.Softmax(dim=1),  # should be removed and replaced by ReLU for category_layer
        )
        model.cluster_layer[0].weight.data.normal_(0, 0.01)
        model.cluster_layer[0].bias.data.zero_()
        model.cluster_layer = model.cluster_layer.double()
        model.cluster_layer.to(device)

        # train network with clusters as pseudo-labels
        with torch.autograd.set_detect_anomaly(True):
            pseudo_loss, semi_loss, semi_accuracy = semi_train(train_dataloader, dataloader_semi, model, fd, criterion,
                                                               optimizer_body, optimizer_category, epoch, device=device, args=args)

        # save checkpoint
        if (epoch + 1) % args.checkpoints == 0:
            path = os.path.join(
                args.exp,
                'checkpoints',
                'checkpoint_' + str(epoch) + '.pth.tar',
            )
            if args.verbose:
                print('Save checkpoint at: {0}'.format(path))
            torch.save({'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer_body': optimizer_body.state_dict(),
                        'optimizer_category': optimizer_category.state_dict(),
                        }, path)

        # ---------------- TEST phase ----------------
        test_loss_bal, test_accuracy_bal, test_pred_bal, test_label_bal = test(dataloader_test_bal, model, criterion, device, args)
        test_loss_unbal, test_accuracy_unbal, test_pred_unbal, test_label_unbal = test(dataloader_test_unbal, model, criterion, device, args)

        # Save prediction of the test set
        if (epoch % args.save_epoch == 0):
            with open(os.path.join(args.exp, 'bal', 'pred', 'sup_epoch_%d_te_bal.pickle' % epoch), "wb") as f:
                pickle.dump([test_pred_bal, test_label_bal], f)
            with open(os.path.join(args.exp, 'unbal', 'pred', 'sup_epoch_%d_te_unbal.pickle' % epoch), "wb") as f:
                pickle.dump([test_pred_unbal, test_label_unbal], f)

        if args.verbose:
            print('###### Epoch [{0}] ###### \n'
                  'Time: {1:.3f} s\n'
                  'Pseudo tr_loss: {2:.3f} \n'
                  'SEMI tr_loss: {3:.3f} \n'
                  'TEST_bal loss: {4:.3f} \n'
                  'TEST_unbal loss: {5:.3f} \n'
                  'Clustering loss: {6:.3f} \n\n'
                  'SEMI accu: {7:.3f} \n'
                  'TEST_bal accu: {8:.3f} \n'
                  'TEST_unbal accu: {9:.3f} \n'
                  .format(epoch, time.time() - end, pseudo_loss, semi_loss,
                          test_loss_bal, test_loss_unbal, clustering_loss, semi_accuracy, test_accuracy_bal, test_accuracy_unbal))
            try:
                # IndexError here means no previous assignment has been logged
                # yet, in which case the NMI is simply skipped.
                nmi = normalized_mutual_info_score(
                    clustering.arrange_clustering(deepcluster.images_lists),
                    clustering.arrange_clustering(cluster_log.data[-1])
                )
                nmi_save.append(nmi)
                print('NMI against previous assignment: {0:.3f}'.format(nmi))
                with open(os.path.join(args.exp, 'nmi_collect.pickle'), "wb") as ff:
                    pickle.dump(nmi_save, ff)
            except IndexError:
                pass
            print('####################### \n')

        # save cluster assignments
        cluster_log.log(deepcluster.images_lists)

        # save running checkpoint
        torch.save({'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer_body': optimizer_body.state_dict(),
                    'optimizer_category': optimizer_category.state_dict(),
                    },
                   os.path.join(args.exp, 'checkpoint.pth.tar'))
        torch.save(model.category_layer.state_dict(), os.path.join(args.exp, 'category_layer.pth.tar'))

        loss_collect[0].append(epoch)
        loss_collect[1].append(pseudo_loss)
        loss_collect[2].append(semi_loss)
        loss_collect[3].append(clustering_loss)
        loss_collect[4].append(test_loss_bal)
        loss_collect[5].append(test_loss_unbal)
        loss_collect[6].append(semi_accuracy)
        loss_collect[7].append(test_accuracy_bal)
        loss_collect[8].append(test_accuracy_unbal)
        with open(os.path.join(args.exp, 'loss_collect.pickle'), "wb") as f:
            pickle.dump(loss_collect, f)

        if (epoch % args.save_epoch == 0):
            produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster)
            produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster)

        # EarlyStopping
        # NOTE(review): this passes the whole balanced-test-accuracy history,
        # not just the current value; confirm that is what
        # EarlyStopping.check expects.
        if early_stopping.check(loss_collect[7], epoch):
            break

    produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster)
    produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args,
                              deepcluster)


        '''