예제 #1
0
    def __init__(self):
        """Parse the generator's command-line arguments and validate them.

        The parsed namespace is stored on ``self.args`` and echoed to
        stdout. A ``--temperature`` below 1e-3 is reported through
        ``parser.error``; afterwards the namespace is handed to
        ``seed_torch``.
        """
        cli = get_lm_generator_parser()
        options = cli.parse_args()
        self.args = options
        print(options)

        # The temperature floor is checked before seed_torch is called.
        temperature_too_low = options.temperature < 1e-3
        if temperature_too_low:
            cli.error("--temperature has to be greater or equal 1e-3")

        seed_torch(options)
예제 #2
0
 def __init__(self):
     """Initialise the instance from the evaluator's command line.

     Sets ``self.args`` (parsed namespace, also printed), ``self.vocab``
     and ``self.ntokens`` (from ``args.vocab_file``), ``self.data`` (an
     ``LMEvalDataset`` over ``args.data``) and ``self.out`` (a file opened
     for writing at ``args.outf``).
     """
     cli = get_lm_evaluator_parser()
     options = cli.parse_args()
     self.args = options
     print(options)
     seed_torch(options)

     vocabulary = Vocab(options.vocab_file)
     self.vocab = vocabulary
     self.ntokens = vocabulary.get_size()

     data_path = os.path.join(options.data)
     self.data = LMEvalDataset(data_path, options.vocab_file)

     # The handle stays open on the instance; this method never closes it.
     self.out = open(options.outf, "w")
예제 #3
0
    def __init__(self):
        """Initialise the instance from command-line arguments.

        Parses the arguments into ``self.args``, restricts ``args.gpu`` to
        machines where CUDA is available, fills in a default experiment
        name, and then creates the checkpoint, meter, logger and timer
        before loading the datasets and computing ``self.weights``.

        ``self.resuming`` is true when ``<experiment_name>.log`` already
        exists inside ``args.logs_dir``.
        """
        options = get_parser().parse_args()
        self.args = options

        # The GPU flag only survives if a CUDA device is actually present.
        options.gpu = options.gpu and torch.cuda.is_available()

        seed_torch(options)

        if options.experiment_name is None:
            options.experiment_name = get_experiment_name(options)

        logs_dir = options.logs_dir
        log_path = os.path.join(logs_dir, options.experiment_name + '.log')
        self.resuming = os.path.exists(log_path)

        # Checkpoint receives the instance itself, so everything it may
        # read (args, resuming) is assigned before this point.
        self.checkpoint = Checkpoint(self)

        self.num_classes = 2
        self.meter = Meter(self.num_classes)

        logger = Logger(options)
        self.writer = logger
        logger.write(options)

        self.timer = Timer()
        self.load_datasets()
        self.weights = self.get_imbalance()
예제 #4
0
    def __init__(self):
        """Initialise the instance from the language-model command line.

        Loads ``train.tsv``, ``valid.tsv`` and ``test.tsv`` from
        ``args.data`` as ``LMDataset`` objects, records the training
        vocabulary size in ``args.vocab_size``, batchifies the tokens of
        each split, and finally creates the checkpoint, logger and timer.
        """
        options = get_lm_parser().parse_args()
        self.args = options

        seed_torch(options)

        print("Loading datasets")
        train_path = os.path.join(options.data, 'train.tsv')
        self.train_data = LMDataset(train_path, options.vocab_file)
        print("Train dataset loaded")

        valid_path = os.path.join(options.data, 'valid.tsv')
        self.val_data = LMDataset(valid_path, options.vocab_file)
        print("Val dataset loaded")

        test_path = os.path.join(options.data, 'test.tsv')
        self.test_data = LMDataset(test_path, options.vocab_file)
        print("Test dataset loaded")

        options.vocab_size = self.train_data.get_vocab_size()
        # The GPU flag only survives if a CUDA device is actually present.
        options.gpu = options.gpu and torch.cuda.is_available()

        # NOTE: this message is emitted before the batchify calls below run.
        print("Created dataloaders")
        train_tokens = self.train_data.get_tokens()
        self.train_loader = batchify(train_tokens, options.batch_size, options)

        val_tokens = self.val_data.get_tokens()
        self.val_loader = batchify(val_tokens, options.batch_size, options)

        test_tokens = self.test_data.get_tokens()
        self.test_loader = batchify(test_tokens, options.batch_size, options)

        if options.experiment_name is None:
            options.experiment_name = get_lm_experiment_name(options)

        # Checkpoint receives the instance itself, so it is built only after
        # the datasets, loaders and args are in place.
        self.checkpoint = Checkpoint(self)

        logger = Logger(options)
        self.writer = logger
        logger.write(options)

        self.timer = Timer()
예제 #5
0
def test(args):
    """Evaluate a saved model on a dataset and report its metrics.

    Reads the following attributes from ``args``: ``vocab_file``,
    ``dataset_path``, ``gpu``, ``model_file``, ``embedding_file`` and
    ``output_file``.

    Each example is fed one at a time through the loaded embedding and
    model; the squeezed model output is thresholded at 0.5 to obtain the
    prediction passed to ``Meter``. The Matthews and accuracy values
    returned by the meter are printed. When ``args.output_file`` is not
    ``None``, the raw (un-thresholded) output of every example is written
    to that file, one float per line.
    """
    use_gpu = args.gpu

    vocab = Vocab(args.vocab_file, True)
    dataset = AcceptabilityDataset(args, args.dataset_path, vocab)

    seed_torch(args)

    # Without a GPU, keep every stored tensor on the CPU while loading;
    # ``map_location=None`` is torch.load's default behaviour.
    # NOTE: torch.load unpickles its input - only load trusted files.
    map_location = None if use_gpu else (lambda storage, loc: storage)
    model = torch.load(args.model_file, map_location=map_location)
    embedding = torch.load(args.embedding_file, map_location=map_location)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
                                         pin_memory=use_gpu,
                                         shuffle=False)

    meter = Meter(2)

    model.eval()
    embedding.eval()
    scores = []

    for x, y, _ in loader:
        x, y = Variable(x).long(), Variable(y)

        if use_gpu:
            x = x.cuda()
            y = y.cuda()

        output = model(embedding(x))

        # When the model returns a tuple (including tuple subclasses such
        # as namedtuples), only its first element is used.
        if isinstance(output, tuple):
            output = output[0]

        score = output.squeeze()
        prediction = (score > 0.5).long()
        scores.append(float(score))

        # NOTE(review): only the CPU path restores a leading dimension
        # here; presumably this matches what Meter.add expects - confirm.
        if not use_gpu:
            prediction = prediction.unsqueeze(0)

        meter.add(prediction.data, y.data)

    print("Matthews %.5f, Accuracy: %.5f" %
          (meter.matthews(), meter.accuracy()))

    if args.output_file is not None:
        # The context manager closes the file even if a write fails.
        with open(args.output_file, "w") as out_file:
            out_file.writelines(str(score) + "\n" for score in scores)