Example #1
def main():
    logger = logging.getLogger('logger')
    utilities.configure_logger(logger, console_only=True)
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', help="Path to the model to be evaluated")
    parser.add_argument(
        '-c',
        '--cuda',
        action='store_true',
        help="Whether to enable calculation on the GPU through CUDA. Defaults to false."
    )

    cli_args = parser.parse_args()
    use_cuda = cli_args.cuda

    model = Siamese(dropout=False)
    if use_cuda:
        model.cuda()
    utils.network.load_model(model, cli_args.model_path, use_cuda=use_cuda)

    model = model.eval()
    data = VocalSketch_1_1()

    partitions = Partitions(data, PartitionSplit(.35, .15, .5))

    dataset = AllPairs(partitions.test)
    rrs = reciprocal_ranks(model, dataset, use_cuda)
    utilities.log_final_stats(rrs)
Example #2
class Triplet(nn.Module):
    def __init__(self, dropout=True, normalization=True):
        super(Triplet, self).__init__()
        self.siamese = Siamese(dropout=dropout, normalization=normalization)

        linear_layer = nn.Linear(2, 1)

        # bias must be 1-D ([out_features]) to match nn.Linear's expected shape
        init_weights = torch.Tensor([[50, -50]]).float()
        init_bias = torch.Tensor([0]).float()

        # requires_grad must be passed to Parameter itself: setting it on the
        # source tensors is lost, since Parameter defaults to requires_grad=True
        linear_layer.weight = torch.nn.Parameter(init_weights, requires_grad=False)
        linear_layer.bias = torch.nn.Parameter(init_bias, requires_grad=False)

        self.final_layer = nn.Sequential(linear_layer, nn.Sigmoid())

    def forward(self, query, near, far):
        # TODO: we can optimize this by only calculating the left/imitation branch once
        near_output = self.siamese(query, near)
        far_output = self.siamese(query, far)

        near_reshaped = near_output.view(len(near_output), -1)
        far_reshaped = far_output.view(len(far_output), -1)
        concatenated = torch.cat((near_reshaped, far_reshaped), dim=1)

        output = self.final_layer(concatenated)
        return output.view(-1)

    def load_siamese(self, model: nn.Module):
        self.siamese.load_state_dict(model.state_dict())
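A note on the frozen comparator above: setting `requires_grad = False` on a plain tensor is lost once the tensor is wrapped in `nn.Parameter`, which defaults to `requires_grad=True`; that is why the constructor passes the flag to `nn.Parameter` directly. A minimal, self-contained sketch of the pattern (toy inputs, not the project's code):

import torch
import torch.nn as nn

# A tiny frozen comparator: maps (near, far) scores to a single preference score.
layer = nn.Linear(2, 1)
layer.weight = nn.Parameter(torch.tensor([[50.0, -50.0]]), requires_grad=False)
layer.bias = nn.Parameter(torch.tensor([0.0]), requires_grad=False)

x = torch.rand(4, 2, requires_grad=True)
torch.sigmoid(layer(x)).sum().backward()

print(layer.weight.grad)       # None: the frozen layer accumulates no gradient
print(x.grad is not None)      # True: gradients still flow through to the inputs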
Example #3
def main(config, args):
    use_cuda = config['use_gpu']
    device = torch.device("cuda" if use_cuda == 1 else "cpu")
    model = Siamese()
    model = model.to(device)

    rec_loss = nn.L1Loss()
    cosine_loss = nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0, betas=(0.9, 0.98), eps=1e-9)
    dataset_train = SpeechDataGenerator(args.clean_file, args.noisy_file, batch_s=100)
    dataloader_train = DataLoader(dataset_train, batch_size=1, shuffle=True, collate_fn=speech_collate)
    for epoch in range(1, config['num_epochs'] + 1):
        train_loss = train(model, dataloader_train, epoch, optimizer, device, rec_loss, cosine_loss)
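The `train` helper called above is defined elsewhere in that project. The following is a hypothetical sketch of an epoch loop consistent with its signature; the batch layout (clean/noisy pairs) and the model's return values are assumptions for illustration, not the original implementation:

def train(model, dataloader, epoch, optimizer, device, rec_loss, cosine_loss):
    # Assumed: each batch is a (clean, noisy) pair and the model returns a
    # reconstruction plus one embedding per branch (all hypothetical).
    model.train()
    total = 0.0
    for clean, noisy in dataloader:
        clean, noisy = clean.to(device), noisy.to(device)
        optimizer.zero_grad()
        recon, emb_clean, emb_noisy = model(clean, noisy)
        loss = rec_loss(recon, clean) + (1 - cosine_loss(emb_clean, emb_noisy)).mean()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / max(len(dataloader), 1)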
Example #5
def initialize_siamese_params(regenerate, dropout):
    logger = logging.getLogger('logger')
    starting_weights_path = "./model_output/siamese_init/starting_weights"

    model = Siamese(dropout=dropout)
    if not regenerate:
        load_model(model, starting_weights_path)

    logger.debug("Saving initial weights/biases at {0}...".format(starting_weights_path))
    save_model(model, starting_weights_path)

    trial_path = "./output/{0}/init_weights".format(get_trial_number())
    logger.debug("Saving initial weights/biases at {0}...".format(trial_path))
    save_model(model, trial_path)
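`save_model` and `load_model` are project-specific helpers. Assuming they are thin wrappers around `state_dict` serialization, they presumably reduce to something like this sketch:

import torch

def save_model(model, path):
    # Persist only the parameters, not the module object itself.
    torch.save(model.state_dict(), path)

def load_model(model, path, use_cuda=False):
    # map_location lets CPU-only machines read GPU-trained weights.
    device = 'cuda' if use_cuda else 'cpu'
    model.load_state_dict(torch.load(path, map_location=device))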
Example #6
def main(cli_args=None):
    utilities.update_trial_number()
    utilities.create_output_directory()

    logger = logging.getLogger('logger')
    parser = argparse.ArgumentParser()
    utilities.configure_parser(parser)
    utilities.configure_logger(logger)
    if cli_args is None:
        cli_args = parser.parse_args()

    logger.info('Beginning trial #{0}...'.format(utilities.get_trial_number()))
    log_cli_args(cli_args)
    try:
        datafiles = VocalImitation(recalculate_spectrograms=cli_args.recalculate_spectrograms)
        data_split = PartitionSplit(*cli_args.partitions)
        partitions = Partitions(datafiles, data_split, cli_args.num_categories, regenerate=False)
        partitions.generate_partitions(PairPartition, no_test=True)
        partitions.save("./output/{0}/partition.pickle".format(utilities.get_trial_number()))

        if cli_args.triplet:
            model = Triplet(dropout=cli_args.dropout)
        elif cli_args.pairwise:
            model = Siamese(dropout=cli_args.dropout)
        else:
            raise ValueError("You must specify the type of the model that is to be evaluated (triplet or pairwise)")

        if cli_args.cuda:
            model = model.cuda()

        evaluated_epochs = np.arange(0, 300, step=5)
        model_directory = './model_output/{0}'.format('pairwise' if cli_args.pairwise else 'triplet') + '/model_{0}'
        model_paths = [model_directory.format(n) for n in evaluated_epochs]
        n_memorized = []
        memorized_var = []
        for model_path in model_paths:
            utils.network.load_model(model, model_path, cli_args.cuda)
            n, v = num_memorized_canonicals(model if cli_args.pairwise else model.siamese, AllPairs(partitions.train),
                                            cli_args.cuda)
            logger.info("n = {0}\nv={1}".format(n, v))
            n_memorized.append(n)
            memorized_var.append(v)

            num_canonical_memorized(memorized_var, n_memorized, evaluated_epochs[:len(n_memorized)], cli_args.num_categories)

    except Exception as e:
        logger.critical("Unhandled exception: {0}".format(str(e)))
        logger.critical(traceback.format_exc())
        sys.exit()
Example #7
def train_siamese_network(model: Siamese, data: PairedDataset, objective, optimizer, n_epochs, use_cuda, batch_size=128):
    for epoch in range(n_epochs):
        # because the model is passed by reference and this is a generator, ensure that we're back in training mode
        model = model.train()

        # notify the dataset that an epoch has passed
        data.epoch_handler()

        batch_sampler = BatchSampler(BalancedPairSampler(data, batch_size), batch_size=batch_size, drop_last=False)
        train_data = DataLoader(data, batch_sampler=batch_sampler, num_workers=4)

        train_data_len = math.ceil(len(train_data.dataset) / batch_size)
        batch_losses = np.zeros(train_data_len)
        bar = Bar("Training siamese, epoch {0}".format(epoch), max=train_data_len)
        for i, (left, right, labels) in enumerate(train_data):
            # clear out the gradients
            optimizer.zero_grad()

            labels = labels.float()
            left = left.float()
            right = right.float()

            # reshape tensors and push to GPU if necessary
            left = left.unsqueeze(1)
            right = right.unsqueeze(1)
            if use_cuda:
                left = left.cuda()
                right = right.cuda()
                labels = labels.cuda()

            # pass a batch through the network
            outputs = model(left, right)

            # calculate loss and optimize weights
            loss = objective(outputs, labels)
            batch_losses[i] = loss.item()
            loss.backward()
            optimizer.step()

            bar.next()
        bar.finish()

        yield model, batch_losses
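Because `train_siamese_network` is a generator, it hands back the model and per-batch losses after every epoch and leaves checkpointing and validation to the caller (Example #13 consumes it exactly this way). A minimal usage sketch:

models = train_siamese_network(siamese, training_data, criterion, optimizer,
                               n_epochs=10, use_cuda=False)
for epoch, (model, batch_losses) in enumerate(models):
    print("epoch {0}: mean training loss {1:.4f}".format(epoch, batch_losses.mean()))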
Example #8
def siamese_loss(model: Siamese,
                 dataset,
                 objective,
                 use_cuda: bool,
                 batch_size=128):
    """
    Calculates the loss of model over dataset by objective. Optionally run on the GPU.
    :param model: a siamese network
    :param dataset: a dataset of imitation/reference pairs
    :param objective: loss function
    :param use_cuda: whether to run on GPU or not.
    :param batch_size: optional param to set batch_size. Defaults to 128.
    :return: ndarray of the mean loss for each batch
    """
    model = model.eval()
    dataset.epoch_handler()

    data = DataLoader(dataset, batch_size=batch_size, num_workers=4)
    bar = Bar("Calculating loss", max=len(data))
    batch_losses = np.zeros(len(data))
    for i, (left, right, labels) in enumerate(data):
        labels = labels.float()
        left = left.float()
        right = right.float()

        # reshape tensors and push to GPU if necessary
        left = left.unsqueeze(1)
        right = right.unsqueeze(1)
        if use_cuda:
            left = left.cuda()
            right = right.cuda()
            labels = labels.cuda()

        # pass a batch through the network
        outputs = model(left, right)

        # calculate loss and optimize weights
        batch_losses[i] = objective(outputs, labels).item()

        bar.next()
    bar.finish()

    return batch_losses
Example #9
def pairwise_inference_matrix(model: Siamese, pairs_dataset: AllPairs,
                              use_cuda):
    """
    Calculates the pairwise inference matrix for a given model across a set of pairs (typically, all of them).

    :param model: siamese network
    :param pairs_dataset: dataset of desired pairs to calculate pairwise matrix across
    :param use_cuda: bool, whether to run on GPU
    :return: pairwise matrix
    """
    rrs = np.array([])
    pairs = dataloader.DataLoader(pairs_dataset, batch_size=128, num_workers=4)
    model = model.eval()
    bar = Bar("Calculating pairwise inference matrix", max=len(pairs))
    for imitations, references, label in pairs:
        label = label.float()
        imitations = imitations.float()
        references = references.float()

        # reshape tensors and push to GPU if necessary
        imitations = imitations.unsqueeze(1)
        references = references.unsqueeze(1)
        if use_cuda:
            imitations = imitations.cuda()
            references = references.cuda()

        output = model(imitations, references)
        # Detach the gradient, move to cpu, and convert to an ndarray
        np_output = output.detach().cpu().numpy()
        rrs = np.concatenate([rrs, np_output])

        bar.next()
    bar.finish()

    # Reshape vector into matrix
    rrs = rrs.reshape([pairs_dataset.n_imitations, pairs_dataset.n_references])
    return rrs
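Given such a pairwise matrix, reciprocal ranks follow by ranking each imitation's correct reference among all references. A sketch under the assumption that the ground-truth reference for row i sits in column i (the real dataset may index ground truth differently):

import numpy as np

def reciprocal_ranks_from_matrix(matrix):
    # matrix[i, j]: model score for imitation i against reference j.
    n = matrix.shape[0]
    correct = matrix[np.arange(n), np.arange(n)]
    # Rank = 1 + number of references scoring strictly higher than the correct one.
    ranks = 1 + (matrix > correct[:, None]).sum(axis=1)
    return 1.0 / ranks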
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=5,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=30000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=10000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder',
                        default='cnn',
                        help='encoder: cnn or bert or roberta')
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay',
                        default=1e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--na_rate',
                        default=0,
                        type=int,
                        help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd', help='sgd / adam / adamw')
    parser.add_argument('--hidden_size',
                        default=230,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')

    # only for bert / roberta
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--pretrain_ckpt',
                        default=None,
                        help='bert / roberta pre-trained checkpoint')
    parser.add_argument(
        '--cat_entity_rep',
        action='store_true',
        help='concatenate entity representation as sentence rep')

    # only for prototypical networks
    parser.add_argument('--dot',
                        action='store_true',
                        help='use dot instead of L2 distance for proto')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(
                open('./pretrain/glove/glove_word2id.json'))
        except:
            raise Exception(
                "Cannot find glove files. Run glove/download_glove.sh to download glove files."
            )
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id,
                                              max_length)
    elif encoder_name == 'bert':
        pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
        if opt.pair:
            sentence_encoder = BERTPAIRSentenceEncoder(pretrain_ckpt,
                                                       max_length)
        else:
            sentence_encoder = BERTSentenceEncoder(
                pretrain_ckpt, max_length, cat_entity_rep=opt.cat_entity_rep)
    elif encoder_name == 'roberta':
        pretrain_ckpt = opt.pretrain_ckpt or 'roberta-base'
        if opt.pair:
            sentence_encoder = RobertaPAIRSentenceEncoder(
                pretrain_ckpt, max_length)
        else:
            sentence_encoder = RobertaSentenceEncoder(
                pretrain_ckpt, max_length, cat_entity_rep=opt.cat_entity_rep)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train,
                                            sentence_encoder,
                                            N=trainN,
                                            K=K,
                                            Q=Q,
                                            na_rate=opt.na_rate,
                                            batch_size=batch_size,
                                            encoder_name=encoder_name)
        val_data_loader = get_loader_pair(opt.val,
                                          sentence_encoder,
                                          N=N,
                                          K=K,
                                          Q=Q,
                                          na_rate=opt.na_rate,
                                          batch_size=batch_size,
                                          encoder_name=encoder_name)
        test_data_loader = get_loader_pair(opt.test,
                                           sentence_encoder,
                                           N=N,
                                           K=K,
                                           Q=Q,
                                           na_rate=opt.na_rate,
                                           batch_size=batch_size,
                                           encoder_name=encoder_name)
    else:
        train_data_loader = get_loader(opt.train,
                                       sentence_encoder,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       na_rate=opt.na_rate,
                                       batch_size=batch_size)
        val_data_loader = get_loader(opt.val,
                                     sentence_encoder,
                                     N=N,
                                     K=K,
                                     Q=Q,
                                     na_rate=opt.na_rate,
                                     batch_size=batch_size)
        test_data_loader = get_loader(opt.test,
                                      sentence_encoder,
                                      N=N,
                                      K=K,
                                      Q=Q,
                                      na_rate=opt.na_rate,
                                      batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv,
                                                      sentence_encoder,
                                                      N=trainN,
                                                      K=K,
                                                      Q=Q,
                                                      na_rate=opt.na_rate,
                                                      batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader,
                                       val_data_loader,
                                       test_data_loader,
                                       adv_data_loader,
                                       adv=opt.adv,
                                       d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader,
                                       test_data_loader)

    prefix = '-'.join(
        [model_name, encoder_name, opt.train, opt.val,
         str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)
    if opt.dot:
        prefix += '-dot'
    if opt.cat_entity_rep:
        prefix += '-catentity'

    if model_name == 'proto':
        model = Proto(sentence_encoder, dot=opt.dot)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N, hidden_size=opt.hidden_size)
    elif model_name == 'snail':
        model = SNAIL(sentence_encoder, N, K, hidden_size=opt.hidden_size)
    elif model_name == 'metanet':
        model = MetaNet(N, K, sentence_encoder.embedding, max_length)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder,
                        hidden_size=opt.hidden_size,
                        dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name in ['bert', 'roberta']:
            bert_optim = True
        else:
            bert_optim = False

        framework.train(model,
                        prefix,
                        batch_size,
                        trainN,
                        N,
                        K,
                        Q,
                        pytorch_optim=pytorch_optim,
                        load_ckpt=opt.load_ckpt,
                        save_ckpt=ckpt,
                        na_rate=opt.na_rate,
                        val_step=opt.val_step,
                        fp16=opt.fp16,
                        pair=opt.pair,
                        train_iter=opt.train_iter,
                        val_iter=opt.val_iter,
                        bert_optim=bert_optim)
    else:
        ckpt = opt.load_ckpt

    acc = framework.eval(model,
                         batch_size,
                         N,
                         K,
                         Q,
                         opt.test_iter,
                         na_rate=opt.na_rate,
                         ckpt=ckpt,
                         pair=opt.pair)
    print("RESULT: %.2f" % (acc * 100))
Example #11

manualSeed = 9302  #random.randint(1, 10000) # fix seed
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

g_config = get_config()

model_dir = args.model_dir
setupLogger(os.path.join(model_dir, 'log.txt'))
g_config.model_dir = model_dir

criterion = nn.HingeEmbeddingLoss()
model = Siamese()

# load model snapshot
load_path = args.load_path
if load_path != '':
    snapshot = torch.load(load_path)
    # loadModelState(model, snapshot)
    model.load_state_dict(snapshot['state_dict'])
    logging('Model loaded from {}'.format(load_path))

train_model(model,
            criterion,
            train_loader,
            test_loader,
            g_config,
            use_cuda=False)
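The snapshot loaded above is a plain dict keyed by 'state_dict'. Assuming that layout, the matching save side would look roughly like:

def save_snapshot(model, path, epoch):
    # Hypothetical counterpart to the load above; the key mirrors snapshot['state_dict'].
    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, path)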
Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=5,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=20000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=2000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder',
                        default='cnn',
                        help='encoder: cnn or bert')
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay',
                        default=1e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--na_rate',
                        default=0,
                        type=int,
                        help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim',
                        default='sgd',
                        help='sgd / adam / bert_adam')
    parser.add_argument('--hidden_size',
                        default=230,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--language', type=str, default='eng', help='language')
    parser.add_argument('--sup_cost',
                        type=int,
                        default=0,
                        help='use sup classifier')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length
    sup_cost = bool(opt.sup_cost)
    print(sup_cost)

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    embsize = 50
    if opt.language == 'chn':
        embsize = 100

    if encoder_name == 'cnn':
        try:
            if opt.language == 'chn':
                glove_mat = np.load('./pretrain/chinese_emb/emb.npy')
                glove_word2id = json.load(
                    open('./pretrain/chinese_emb/word2id.json'))
            else:
                glove_mat = np.load('./pretrain/glove/glove_mat.npy')
                glove_word2id = json.load(
                    open('./pretrain/glove/glove_word2id.json'))
        except:
            raise Exception(
                "Cannot find glove files. Run glove/download_glove.sh to download glove files."
            )
        sentence_encoder = CNNSentenceEncoder(glove_mat,
                                              glove_word2id,
                                              max_length,
                                              word_embedding_dim=embsize)
    elif encoder_name == 'bert':
        if opt.pair:
            if opt.language == 'chn':
                sentence_encoder = BERTPAIRSentenceEncoder(
                    'bert-base-chinese',  #'./pretrain/bert-base-uncased',
                    max_length)
            else:
                sentence_encoder = BERTPAIRSentenceEncoder(
                    'bert-base-uncased', max_length)
        else:
            if opt.language == 'chn':
                sentence_encoder = BERTSentenceEncoder(
                    'bert-base-chinese',  #'./pretrain/bert-base-uncased',
                    max_length)
            else:
                sentence_encoder = BERTSentenceEncoder('bert-base-uncased',
                                                       max_length)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train,
                                            sentence_encoder,
                                            N=trainN,
                                            K=K,
                                            Q=Q,
                                            na_rate=opt.na_rate,
                                            batch_size=batch_size)
        val_data_loader = get_loader_pair(opt.val,
                                          sentence_encoder,
                                          N=N,
                                          K=K,
                                          Q=Q,
                                          na_rate=opt.na_rate,
                                          batch_size=batch_size)
        test_data_loader = get_loader_pair(opt.test,
                                           sentence_encoder,
                                           N=N,
                                           K=K,
                                           Q=Q,
                                           na_rate=opt.na_rate,
                                           batch_size=batch_size)
    else:
        train_data_loader = get_loader(opt.train,
                                       sentence_encoder,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       na_rate=opt.na_rate,
                                       batch_size=batch_size)
        val_data_loader = get_loader(opt.val,
                                     sentence_encoder,
                                     N=N,
                                     K=K,
                                     Q=Q,
                                     na_rate=opt.na_rate,
                                     batch_size=batch_size)
        test_data_loader = get_loader(opt.test,
                                      sentence_encoder,
                                      N=N,
                                      K=K,
                                      Q=Q,
                                      na_rate=opt.na_rate,
                                      batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv,
                                                      sentence_encoder,
                                                      N=trainN,
                                                      K=K,
                                                      Q=Q,
                                                      na_rate=opt.na_rate,
                                                      batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'bert_adam':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader,
                                       val_data_loader,
                                       test_data_loader,
                                       adv_data_loader,
                                       adv=opt.adv,
                                       d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader,
                                       test_data_loader)

    prefix = '-'.join(
        [model_name, encoder_name, opt.train, opt.val,
         str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)

    if model_name == 'proto':
        model = Proto(sentence_encoder, hidden_size=opt.hidden_size)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N, use_sup_cost=sup_cost)
    elif model_name == 'snail':
        print("HINT: SNAIL works only in PyTorch 0.3.1")
        model = SNAIL(sentence_encoder, N, K)
    elif model_name == 'metanet':
        model = MetaNet(N,
                        K,
                        sentence_encoder.embedding,
                        max_length,
                        use_sup_cost=sup_cost)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder,
                        hidden_size=opt.hidden_size,
                        dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name == 'bert':
            bert_optim = True
        else:
            bert_optim = False
        framework.train(model,
                        prefix,
                        batch_size,
                        trainN,
                        N,
                        K,
                        Q,
                        pytorch_optim=pytorch_optim,
                        load_ckpt=opt.load_ckpt,
                        save_ckpt=ckpt,
                        na_rate=opt.na_rate,
                        val_step=opt.val_step,
                        fp16=opt.fp16,
                        pair=opt.pair,
                        train_iter=opt.train_iter,
                        val_iter=opt.val_iter,
                        bert_optim=bert_optim,
                        sup_cls=sup_cost)
    else:
        ckpt = opt.load_ckpt

    acc = framework.eval(model,
                         batch_size,
                         N,
                         K,
                         Q,
                         opt.test_iter,
                         na_rate=opt.na_rate,
                         ckpt=ckpt,
                         pair=opt.pair)
    wfile = open('logs/' + ckpt.replace('checkpoint/', '') + '.txt', 'a')
    wfile.write(str(N) + '\t' + str(K) + '\t' + str(acc * 100) + '\n')
    wfile.close()
    print("RESULT: %.2f" % (acc * 100))
Example #13
def train(use_cuda: bool, n_epochs: int, validate_every: int,
          use_dropout: bool, partitions: Partitions, optimizer_name: str,
          lr: float, wd: float, momentum: bool):
    logger = logging.getLogger('logger')

    no_test = True
    model_path = "./model_output/pairwise/model_{0}"

    partitions.generate_partitions(PairPartition, no_test=no_test)
    training_data = Balanced(partitions.train)

    if validate_every > 0:
        balanced_validation = Balanced(partitions.val)
        training_pairs = AllPairs(partitions.train)
        search_length = training_pairs.n_references
        validation_pairs = AllPairs(partitions.val)
        testing_pairs = AllPairs(partitions.test) if not no_test else None
    else:
        balanced_validation = None
        training_pairs = None
        validation_pairs = None
        testing_pairs = None
        search_length = None

    # get a siamese network, see Siamese class for architecture
    siamese = Siamese(dropout=use_dropout)
    siamese = initialize_weights(siamese, use_cuda)

    if use_cuda:
        siamese = siamese.cuda()

    criterion = BCELoss()
    optimizer = get_optimizer(siamese, optimizer_name, lr, wd, momentum)

    try:
        logger.info("Training network with pairwise loss...")
        progress = TrainingProgress()
        models = training.train_siamese_network(siamese, training_data,
                                                criterion, optimizer, n_epochs,
                                                use_cuda)
        for epoch, (model, training_batch_losses) in enumerate(models):
            utils.network.save_model(model, model_path.format(epoch))

            training_loss = training_batch_losses.mean()
            if validate_every != 0 and epoch % validate_every == 0:
                validation_batch_losses = inference.siamese_loss(
                    model, balanced_validation, criterion, use_cuda)
                validation_loss = validation_batch_losses.mean()

                training_mrr, training_rank = inference.mean_reciprocal_ranks(
                    model, training_pairs, use_cuda)
                val_mrr, val_rank = inference.mean_reciprocal_ranks(
                    model, validation_pairs, use_cuda)

                progress.add_mrr(train=training_mrr, val=val_mrr)
                progress.add_rank(train=training_rank, val=val_rank)
                progress.add_loss(train=training_loss, val=validation_loss)
            else:
                progress.add_mrr(train=np.nan, val=np.nan)
                progress.add_rank(train=np.nan, val=np.nan)
                progress.add_loss(train=training_loss, val=np.nan)

            progress.graph("Siamese", search_length)

        # load weights from best model if we validated throughout
        if validate_every > 0:
            siamese = siamese.train()
            utils.network.load_model(
                siamese, model_path.format(np.argmax(progress.val_mrr)))

        # otherwise just save most recent model
        utils.network.save_model(siamese, model_path.format('best'))
        utils.network.save_model(
            siamese,
            './output/{0}/pairwise'.format(utilities.get_trial_number()))

        if not no_test:
            logger.info(
                "Results from best model generated during training, evaluated on test data:"
            )
            rrs = inference.reciprocal_ranks(siamese, testing_pairs, use_cuda)
            utilities.log_final_stats(rrs)

        progress.pearson(log=True)
        progress.save("./output/{0}/pairwise.pickle".format(
            utilities.get_trial_number()))
        return siamese
    except Exception as e:
        utils.network.save_model(siamese, model_path.format('crash_backup'))
        logger.critical("Exception occurred while training: {0}".format(
            str(e)))
        logger.critical(traceback.format_exc())
        sys.exit()
Example #14
from config import get_config
from utilities import loadModelState, loadAndResizeImage, logging, modelSize
from torch.utils.serialization import load_lua
from torch import nn
from models.siamese import Siamese

load_from_torch7 = False

print('Loading model...')
model_dir = 'models/snapshot/'
model_load_path = os.path.join(model_dir, 'snapshot_epoch_1.pt')
gConfig = get_config()
gConfig.model_dir = model_dir

criterion = nn.HingeEmbeddingLoss()
model = Siamese()

package = torch.load(model_load_path)

model.load_state_dict(package['state_dict'])
model.eval()
print('Model loaded from {}'.format(model_load_path))

logging('Model configuration:\n{}'.format(model))

model_size, nParamsEachLayer = modelSize(model)  # avoid shadowing the modelSize helper
logging('Model size: {}\n{}'.format(model_size, nParamsEachLayer))

params = model.parameters()

for i, a_param in enumerate(params):
Example #15
            return (stacked, label)
        else:
            stacked = np.hstack((rand_loss, rand_win))
            stacked = torch.from_numpy(stacked).type(torch.FloatTensor)
            label = torch.from_numpy(np.array([0, 1])).type(torch.FloatTensor)
            return (stacked, label)

    def __len__(self):
        return self.length

train_loader = torch.utils.data.DataLoader(TrainSet(1000000), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(TestSet(100000), batch_size=batch_size, shuffle=True)


print('Building model...')
model = Siamese().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

e = enumerate(train_loader)
b, (data, label) = next(e)


# Binary cross-entropy loss, summed over all elements and the batch
def loss_function(pred, label):
    # size_average=False is deprecated; reduction='sum' is the modern equivalent
    BCE = F.binary_cross_entropy(pred, label, reduction='sum')
    return BCE


def train(epoch):
    model.train()
    train_loss = 0
Example #16
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=5,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=30000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=3000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder',
                        default='cnn',
                        help='encoder: cnn or bert')
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay',
                        default=1e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--na_rate',
                        default=0,
                        type=int,
                        help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim',
                        default='sgd',
                        help='sgd / adam / bert_adam')
    parser.add_argument('--hidden_size',
                        default=230,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    parser.add_argument('--pair', action='store_true', help='use pair model')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(
                open('./pretrain/glove/glove_word2id.json'))
        except:
            raise Exception(
                "Cannot find glove files. Run glove/download_glove.sh to download glove files."
            )
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id,
                                              max_length)
    elif encoder_name == 'bert':
        if opt.pair:
            sentence_encoder = BERTPAIRSentenceEncoder(
                './pretrain/bert-base-uncased', max_length)
        else:
            sentence_encoder = BERTSentenceEncoder(
                './pretrain/bert-base-uncased', max_length)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train,
                                            sentence_encoder,
                                            N=trainN,
                                            K=K,
                                            Q=Q,
                                            na_rate=opt.na_rate,
                                            batch_size=batch_size)
        val_data_loader = get_loader_pair(opt.val,
                                          sentence_encoder,
                                          N=N,
                                          K=K,
                                          Q=Q,
                                          na_rate=opt.na_rate,
                                          batch_size=batch_size)
        test_data_loader = get_loader_pair(opt.test,
                                           sentence_encoder,
                                           N=N,
                                           K=K,
                                           Q=Q,
                                           na_rate=opt.na_rate,
                                           batch_size=batch_size)
    else:
        train_data_loader = get_loader(opt.train,
                                       sentence_encoder,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       na_rate=opt.na_rate,
                                       batch_size=batch_size)
        val_data_loader = get_loader(opt.val,
                                     sentence_encoder,
                                     N=N,
                                     K=K,
                                     Q=Q,
                                     na_rate=opt.na_rate,
                                     batch_size=batch_size)
        test_data_loader = get_loader(opt.test,
                                      sentence_encoder,
                                      N=N,
                                      K=K,
                                      Q=Q,
                                      na_rate=opt.na_rate,
                                      batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv,
                                                      sentence_encoder,
                                                      N=trainN,
                                                      K=K,
                                                      Q=Q,
                                                      na_rate=opt.na_rate,
                                                      batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'bert_adam':
        from pytorch_transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader,
                                       val_data_loader,
                                       test_data_loader,
                                       adv_data_loader,
                                       adv=opt.adv,
                                       d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader,
                                       test_data_loader)

    prefix = '-'.join(
        [model_name, encoder_name, opt.train, opt.val,
         str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)

    if model_name == 'proto':
        model = Proto(sentence_encoder, hidden_size=opt.hidden_size)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N)
    elif model_name == 'snail':
        print("HINT: SNAIL works only in PyTorch 0.3.1")
        model = SNAIL(sentence_encoder, N, K)
    elif model_name == 'metanet':
        model = MetaNet(N, K, sentence_encoder.embedding, max_length)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder,
                        hidden_size=opt.hidden_size,
                        dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name == 'bert':
            bert_optim = True
        else:
            bert_optim = False

        framework.train(model,
                        prefix,
                        batch_size,
                        trainN,
                        N,
                        K,
                        Q,
                        pytorch_optim=pytorch_optim,
                        load_ckpt=opt.load_ckpt,
                        save_ckpt=ckpt,
                        na_rate=opt.na_rate,
                        val_step=opt.val_step,
                        fp16=opt.fp16,
                        pair=opt.pair,
                        train_iter=opt.train_iter,
                        val_iter=opt.val_iter,
                        bert_optim=bert_optim)
    else:
        ckpt = opt.load_ckpt

    acc = 0
    his_acc = []
    total_test_round = 5
    for i in range(total_test_round):
        cur_acc = framework.eval(model,
                                 batch_size,
                                 N,
                                 K,
                                 Q,
                                 opt.test_iter,
                                 na_rate=opt.na_rate,
                                 ckpt=ckpt,
                                 pair=opt.pair)
        his_acc.append(cur_acc)
        acc += cur_acc
    acc /= total_test_round
    nhis_acc = np.array(his_acc)
    error = nhis_acc.std() * 1.96 / (nhis_acc.shape[0]**0.5)
    print("RESULT: %.2f\\pm%.2f" % (acc * 100, error * 100))

    result_file = open('./result.txt', 'a+')
    result_file.write(
        "test data: %12s | model: %45s | acc: %.6f | error: %.6f\n" %
        (opt.test, prefix, acc, error))
    result_file.close()

    result_file = open('./result_detail.txt', 'a+')
    result_detail = {
        'test': opt.test,
        'model': prefix,
        'acc': acc,
        'his': his_acc
    }
    result_file.write("%s\n" % (json.dumps(result_detail)))