Example #1
def load_models(settings):
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model_class = ClassifierNet(settings).to(device)
    model_emb = EmbeddingNet(settings).to(device)
    model_emb, model_class = load_parameters_model(settings, model_emb,
                                                   model_class)
    model_emb.eval()
    model_class.eval()
    return model_emb, model_class
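
A minimal usage sketch, assuming a project-specific settings object and that ClassifierNet consumes the embedding produced by EmbeddingNet; the input shape is a placeholder:

import torch

model_emb, model_class = load_models(settings)  # settings: project config (placeholder)
device = next(model_emb.parameters()).device

x = torch.randn(8, 1, 32, 32, device=device)  # placeholder input shape
with torch.no_grad():
    emb = model_emb(x)         # embedding vectors
    logits = model_class(emb)  # class scores from the embeddings
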
Example #2
def main(args):
    PATH = os.path.dirname(os.path.realpath(__file__))
    PATH_TRAIN = args.path_train
    FILE_NAME = f'{args.train_file}_{args.random_state}'
    global MODEL_NAME
    MODEL_NAME = args.model_name
    training_log = PATH + f'/training_log/{args.model_name}_training_{args.log}.log'
    with open(training_log, 'a') as f:
        message = f'Training log {args.log} of {args.model_name} \n\n'
        message += f'Starts at {datetime.datetime.now()}\n'
        message += 'Arguments are: \n'
        message += f'{str(args)}\n\n'
        f.write(message)
        f.flush()

    cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if cuda else 'cpu')
    if cuda:
        message = f'Training on {torch.cuda.device_count()} x {torch.cuda.get_device_name()}\n'
    else:
        message = 'Training on CPU\n'
    with open(training_log, 'a') as f:
        f.write(message + '\n')
        f.flush()

    model_params = args.model_params
    # initialize the model
    embedding_net = EmbeddingNet(
        in_channels=grid_input_channels_dict[args.grid_type], **model_params)
    model = TripletNet(embedding_net)
    start_epoch = 0
    message = 'Initialized the model architecture\n'
    # load saved model
    if args.continue_training:
        if args.saved_model is None:
            message = 'Missing saved model name\n'
            with open(training_log, 'a') as f:
                f.write(message)
                f.flush()
            raise ValueError(message)

        message += f'Read saved model {args.saved_model}\n'
        start_epoch = int(re.search(r'Epoch(\d+)', args.saved_model).group(1))
        if cuda:
            map_location = None
        else:
            map_location = torch.device('cpu')
        state_dict = torch.load(f'{PATH}/saved_model/{args.saved_model}',
                                map_location=map_location)
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace('module.', '')  # strip the 'module.' prefix added by nn.DataParallel
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)

    def model_initialization_method(method):
        if method == 'xavier_normal':

            def weights_init(m):
                if isinstance(m, nn.Conv3d):
                    nn.init.xavier_normal_(m.weight.data)
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)

            return weights_init
        if method == 'xavier_uniform':

            def weights_init(m):
                if isinstance(m, nn.Conv3d):
                    nn.init.xavier_uniform_(m.weight.data)
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)

            return weights_init

        raise ValueError(f'Unknown model initialization method: {method}')

    if args.model_init != 'default':
        model.apply(model_initialization_method(args.model_init))
    message += f'Model initialization method:{args.model_init}\n'
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)

    # print total number of trainable parameters and model architecture
    with open(training_log, 'a') as f:
        message += f'Model architecture:\n{str(model)}\n'
        message += f'Total number of trainable parameters is {pytorch_total_params}\n'
        message += f'Training starts at : {datetime.datetime.now()}\n'
        f.write(message)
        f.flush()

    # multi gpu
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    def dataset_class(data_type):
        if data_type == 'molcoord':
            return MolCoordDataset, TripletMolCoord
        if data_type == 'grid':
            return SingleGridDataset, TripletSingleGrid
        raise ValueError(f'Unknown data type: {data_type}')

    DS, TripletDS = dataset_class(args.data_type)

    if args.data_type == 'molcoord':
        t0 = datetime.datetime.now()
        df = pd.read_pickle(f'{PATH_TRAIN}{FILE_NAME}.pkl')
        train_index, test_index = train_test_index(
            df.shape[0], random_state=args.random_state)
        # MolCoordDataset
        molcoords = df.mol.apply(
            lambda x: MolCoord.from_sdf(Chem.MolToMolBlock(x))).tolist()
        # randomly rotate the coordinates
        if args.grid_rotation == 1:
            np.random.seed(args.random_state)
            axis = np.random.rand(df.shape[0], 3) * 2 - 1
            theta = np.random.rand(df.shape[0]) * np.pi * 2
            for i in range(len(molcoords)):
                matrix = torch.Tensor(matrix_from_axis_angel(
                    axis[i], theta[i]))
                molcoords[i].coord_rotation_(matrix)

        train_dataset = DS([molcoords[index] for index in train_index],
                           np.zeros(len(train_index), dtype=int),
                           grid_type=args.grid_type,
                           train=True)
        test_dataset = DS([molcoords[index] for index in test_index],
                          np.zeros(len(test_index), dtype=int),
                          grid_type=args.grid_type,
                          train=False)
        # dataset sizes; the embedding/clustering loop below needs these for
        # both data types, not only for 'grid'
        num_traindata = len(train_index)
        num_testdata = len(test_index)
        with open(training_log, 'a') as f:
            message = f'Preparing dataset costs {datetime.datetime.now() - t0}'
            f.write(message + '\n')
            f.flush()

        # release the raw dataframe before training
        del df
        gc.collect()

    if args.data_type == 'grid':
        grid_path = f'{args.path_train}/grids/grid_{args.grid_type}'
        if args.grid_rotation == 1:
            grid_path += '_rot'
        grid_path += f'/{args.train_file}/{args.train_file}_{args.random_state}_grids'

        num_data = int(
            re.search(r'^.+training_(\d+)$', args.train_file).group(1))
        test_size = 0.2
        num_testdata = int(num_data * test_size)
        num_traindata = num_data - num_testdata

        t0 = datetime.datetime.now()
        # GridDataset
        train_dataset = SingleGridDataset(grid_path,
                                          np.zeros(num_traindata),
                                          train=True)
        test_dataset = SingleGridDataset(grid_path,
                                         np.zeros(num_testdata),
                                         train=False)
        with open(training_log, 'a') as f:
            message = f'Preparing dataset costs {datetime.datetime.now() - t0}'
            f.write(message + '\n')
            f.flush()

    batch_size = args.batch_size * max(1, torch.cuda.device_count())
    margin = args.margin
    loss_fn = WeightedTripletLoss(margin)
    lr = args.learning_rate
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=8,
                                    gamma=0.1,
                                    last_epoch=-1)
    n_epochs = args.n_epochs
    log_interval = args.log_interval

    metrics = []
    es = EarlyStopping(patience=args.es_patience, min_delta=args.es_min_delta)
    min_val_loss = np.inf
    early_stopping_counter = 0
    es_indicator = False
    ckpt_interval = 3

    # save the model after initialization
    if start_epoch == 0:
        with open(training_log, 'a') as f:
            save(model, 'completed', start_epoch, f, args.log)

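    # fast-forward the LR scheduler to the epoch we are resuming from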
    for epoch in range(0, start_epoch):
        scheduler.step()

    for epoch in range(start_epoch, n_epochs):
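        # dummy optimizer.step() so that scheduler.step() can run at the start
        # of the epoch without PyTorch's "lr_scheduler.step() before
        # optimizer.step()" warning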
        optimizer.step()
        scheduler.step()

        t_epoch_start = datetime.datetime.now()
        message = f'Epoch {epoch + 1} starts at {t_epoch_start}'
        with open(training_log, 'a') as f:
            f.write(message + '\n')
            f.flush()
        t0 = datetime.datetime.now()
        # clustering embedding vectors to get labels
        if torch.cuda.device_count() > 1:
            embedding_net = model.module.embedding_net
            embedding_net = nn.DataParallel(embedding_net)
        else:
            embedding_net = model.embedding_net
        embedding_net.eval()
        embedding = []
        for batch_index in divide_batches(list(range(num_traindata)),
                                          batch_size):
            with torch.no_grad():
                embedding.append(
                    embedding_net(
                        train_dataset[batch_index]['grid'].to(device)).cpu())
        for batch_index in divide_batches(list(range(num_testdata)),
                                          batch_size):
            with torch.no_grad():
                embedding.append(
                    embedding_net(
                        test_dataset[batch_index]['grid'].to(device)).cpu())
        embedding = torch.cat(embedding)
        message = f'Epoch {epoch + 1} embedding computation costs {datetime.datetime.now() - t0}'
        with open(training_log, 'a') as f:
            f.write(message + '\n')
            f.flush()

        min_counts_labels = 1
        cls_counter = 1
        random_state = args.random_state

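        # re-cluster with a fresh seed until every train and test cluster holds
        # at least 2 samples (needed for triplet sampling), giving up after 10 tries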
        while min_counts_labels < 2:
            t0 = datetime.datetime.now()
            message = f'Epoch {epoch + 1} clustering {cls_counter} starts at {t0}\n'
            with open(training_log, 'a') as f:
                f.write(message)
                f.flush()
            random.seed(random_state)
            random_state = int(random.random() * 1e6)
            kwargs = {'verbose': 1}
            cls = clustering_method(args.clustering_method, args.num_clusters,
                                    random_state, **kwargs)
            cls.fit(embedding)
            labels = cls.predict(embedding)
            train_labels = labels[:num_traindata]
            test_labels = labels[num_traindata:]
            unique_labels_train, counts_labels_train = np.unique(
                train_labels, return_counts=True)
            unique_labels_test, counts_labels_test = np.unique(
                test_labels, return_counts=True)
            min_counts_labels = min(min(counts_labels_train),
                                    min(counts_labels_test))
            message = f'Epoch {epoch + 1} clustering {cls_counter} ends at {datetime.datetime.now()}\n'
            message += f'Epoch {epoch + 1} clustering {cls_counter} costs {datetime.datetime.now() - t0}\n'
            message += f'{len(unique_labels_train)} clusters for train in total\n'
            message += f'The minimum number of samples in a cluster for train is {min(counts_labels_train)}\n'
            message += f'The maximum number of samples in a cluster for train is {max(counts_labels_train)}\n'
            message += f'{len(unique_labels_test)} clusters for test in total\n'
            message += f'The minimum number of samples in a cluster for test is {min(counts_labels_test)}\n'
            message += f'The maximum number of samples in a cluster for test is {max(counts_labels_test)}\n'
            with open(training_log, 'a') as f:
                f.write(message + '\n')
                f.flush()
            cls_counter += 1
            if cls_counter > 10:
                break

        if min_counts_labels < 2:
            with open(training_log, 'a') as f:
                message = 'Could not obtain a valid clustering. Stopping training.\n'
                f.write(message + '\n')
                f.flush()
            break

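        # inverse-frequency cluster weights, scaled so the per-sample weights
        # average to 1; with weighting disabled, every cluster gets weight 1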
        if args.weighted_loss:
            loss_weights_train = dict(
                zip(
                    unique_labels_train, 1 / counts_labels_train *
                    len(train_labels) / len(unique_labels_train)))
            loss_weights_test = dict(
                zip(
                    unique_labels_test, 1 / counts_labels_test *
                    len(test_labels) / len(unique_labels_test)))
        else:
            loss_weights_train = dict(
                zip(unique_labels_train, np.ones(len(unique_labels_train))))
            loss_weights_test = dict(
                zip(unique_labels_test, np.ones(len(unique_labels_test))))
        train_dataset.labels = train_labels
        test_dataset.labels = test_labels

        t0 = datetime.datetime.now()
        kwargs = {'num_workers': 0, 'pin_memory': True} if cuda else {}
        train_loader = torch.utils.data.DataLoader(TripletDS(
            train_dataset, grid_rotation=args.grid_rotation),
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        test_loader = torch.utils.data.DataLoader(TripletDS(
            test_dataset, grid_rotation=args.grid_rotation),
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  **kwargs)
        message = f'Epoch {epoch + 1} dataloader preparation costs {datetime.datetime.now() - t0}'
        with open(training_log, 'a') as f:
            f.write(message + '\n')
            f.flush()

        # Train stage
        train_loss, metrics = train_epoch(
            train_loader,
            model,
            loss_fn,
            loss_weights_train,
            optimizer,
            cuda,
            log_interval,
            training_log,
            metrics,
            epoch,
            ckpt_interval,
        )

        message = f'Epoch: {epoch + 1}/{n_epochs}. Train set: Average loss: {train_loss:.4f}'

        for metric in metrics:
            message += f'\t{metric.name()}: {metric.value()}'

        val_loss, metrics = test_epoch(test_loader, model, loss_fn,
                                       loss_weights_test, cuda, metrics, epoch)
        val_loss /= len(test_loader)

        message += f'\nEpoch: {epoch + 1}/{n_epochs}. Validation set: Average loss: {val_loss:.4f}'
        for metric in metrics:
            message += f'\t{metric.name()}: {metric.value()}\n'

        message += f'\nEpoch {epoch + 1} costs: {str(datetime.datetime.now() - t_epoch_start)}\n'
        print(message)
        with open(training_log, 'a') as f:
            f.write(message + '\n')
            f.flush()

        # output train loss and validation loss
        with open(
                re.search(r'(.+)\.log', training_log).group(1) + '_loss.csv',
                'a') as f_loss:
            message = f'{epoch + 1},{train_loss},{val_loss}\n'
            f_loss.write(message)
            f_loss.flush()

        min_val_loss, early_stopping_counter, es_indicator = es(
            val_loss, min_val_loss, early_stopping_counter)

        # save a checkpoint every epoch; additionally log the early-stopping
        # state when the validation loss did not improve
        with open(training_log, 'a') as f:
            save(model, 'completed', epoch + 1, f, args.log)
            if early_stopping_counter > 0:
                message = f'min_val_loss: {min_val_loss:.4f}, # of epochs with no improvement: {early_stopping_counter}\n'
                print(message)
                f.write(message + '\n')
                f.flush()

        if es_indicator:
            message = f'min_val_loss: {min_val_loss:.4f}, # of epochs with no improvement: {early_stopping_counter}\n'
            message += f'Early Stopping after epoch {epoch + 1}\n'
            print(message)
            with open(training_log, 'a') as f:
                f.write(message + '\n')
                f.flush()
            break
        else:
            message = f'min_val_loss: {min_val_loss:.4f}, # of epochs with no improvement: {early_stopping_counter}\n'
            message += 'Training continued.\n'
            print(message)
            with open(training_log, 'a') as f:
                f.write(message + '\n')
                f.flush()
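
The EarlyStopping helper above is external to this snippet. A minimal sketch compatible with its call pattern, returning the updated best loss, the no-improvement counter, and a stop flag (the project's actual implementation may differ):

class EarlyStopping:
    """Signal a stop after `patience` epochs without sufficient improvement."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta

    def __call__(self, val_loss, min_val_loss, counter):
        if val_loss < min_val_loss - self.min_delta:
            # improvement: record the new best loss and reset the counter
            return val_loss, 0, False
        counter += 1
        # request early stopping once patience is exhausted
        return min_val_loss, counter, counter >= self.patience
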
Example #3

# copy the embedding-net weights out of the full model's state dict
temp_dict = {}
for key in state_dict.keys():
    if key.startswith('embedding_net.'):
        # strip the 'embedding_net.' prefix so the keys match EmbeddingNet's own
        temp_dict[key[len('embedding_net.'):]] = state_dict[key]
embed_net.load_state_dict(temp_dict)


# generate an embedding for one cropped image
def getEmbedding(file_path, x):
    file_name = os.path.join(file_path, x)
    # bounding box columns: (left, upper, right, lower)
    bbox = bbox_df.loc[bbox_df.Image == x, :].values[0, 1:]
    img_pil = Image.open(file_name).crop(tuple(bbox)).convert('RGB')
    img = np.array(img_pil)
    image = data_transforms_test(image=img)['image'].unsqueeze(0)
    with torch.no_grad():
        vector = embed_net(image)
    return vector


# generate and save embeddings for the train and test sets
embed_net.eval()
train_embed_dataset = train_full.assign(
    embedding=train_full['Image'].apply(lambda x: getEmbedding('train/', x)))
print('training embeddings generated!')
test_embed_dataset = test_df.assign(
    embedding=test_df['Image'].apply(lambda x: getEmbedding('test/', x)))
print('test embeddings generated!')
pickle.dump(train_embed_dataset, open('train_full_embed.p', 'wb'))
print('training embeddings saved to train_full_embed.p')
pickle.dump(test_embed_dataset, open('test_df_embed.p', 'wb'))
print('test embeddings saved to test_df_embed.p')
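
A possible next step, sketched as an assumption rather than part of the original script: match each test embedding to its nearest training embedding by cosine similarity. This assumes every embedding cell is a 1 x d tensor, as returned by getEmbedding above.

import torch
import torch.nn.functional as F

train_vecs = torch.cat(train_embed_dataset['embedding'].tolist())  # (N_train, d)
test_vecs = torch.cat(test_embed_dataset['embedding'].tolist())    # (N_test, d)

# cosine similarity = dot product of L2-normalized vectors
sims = F.normalize(test_vecs, dim=1) @ F.normalize(train_vecs, dim=1).T
nearest = sims.argmax(dim=1)  # closest training image per test image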
Example #4
# assumed preamble: the original snippet begins mid-function; this header follows
# the common siamese-triplet extract_embeddings pattern, and the embedding
# dimension (2) is a guess
def extract_embeddings(dataloader, model, cuda=torch.cuda.is_available()):
    with torch.no_grad():
        model.eval()
        embeddings = np.zeros((len(dataloader.dataset), 2))
        labels = np.zeros(len(dataloader.dataset))
        k = 0
        for images, target in dataloader:
            if cuda:
                images = images.cuda()
            embeddings[k:k + len(images)] = model.get_embedding(
                images).data.cpu().numpy()
            labels[k:k + len(images)] = target.numpy()
            k += len(images)
    return embeddings, labels


# load the trained embedding network and switch it to inference mode;
# map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only machine
model = EmbeddingNet()
model.load_state_dict(torch.load('./saved_model/titi', map_location='cpu'))
model.eval()
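
A minimal usage sketch tying the loaded model to the extraction routine above; eval_dataset is a placeholder for any dataset yielding (image, label) batches:

from torch.utils.data import DataLoader

loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)  # eval_dataset: placeholder
embeddings, labels = extract_embeddings(loader, model)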


final_test_epoch('/Users/ayush/projects/my_pytorch/probe',
                 '/Users/ayush/projects/my_pytorch/gallery',
                 '/Users/ayush/projects/my_pytorch/fp_output_txt',
                 model,
                 metrics=[AverageNonzeroTripletsMetric()],
                 transform=transforms.Compose([transforms.ToTensor()]))


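# remove macOS's .DS_Store so it is not treated as an image file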
get_ipython().system('rm /Users/ayush/projects/my_pytorch/probe/.DS_Store')