# Example #1
def pred_prob(arg_path, field_path, pth_path, doc, device=torch.device('cpu')):
    """Download args/field/checkpoint from the rob-pome repo and score *doc*.

    Args:
        arg_path: repo-relative path to the JSON file holding the 'args' dict.
        field_path: repo-relative path to the pickled torchtext TEXT field.
        pth_path: repo-relative path to the model checkpoint (.pth).
        doc: raw document text to tokenize and classify.
        device: torch device used for inference (default: CPU).

    Returns:
        float: probability of class 1 (RoB reported), i.e. ``probs[1]``.
    """
    base_raw = 'https://raw.githubusercontent.com/qianyingw/rob-pome/master/rob-app'
    base_git = 'https://github.com/qianyingw/rob-pome/raw/master/rob-app'

    # Load args. Join URL segments with '/' explicitly: os.path.join would
    # insert backslashes on Windows and corrupt the URL.
    arg_url = '/'.join([base_raw, arg_path])
    with urllib.request.urlopen(arg_url) as url:
        args = json.loads(url.read().decode())['args']

    # Load TEXT field: download to a local temp file, always clean it up.
    field_url = '/'.join([base_git, field_path])
    field_local = wget.download(field_url)
    try:
        with open(field_local, "rb") as fin:
            TEXT = dill.load(fin)
    finally:
        os.remove(field_local)

    unk_idx = TEXT.vocab.stoi[TEXT.unk_token]  # typically 0
    pad_idx = TEXT.vocab.stoi[TEXT.pad_token]  # typically 1

    # Build the network described by the downloaded args.
    if args['net_type'] == 'cnn':
        sizes = [int(s) for s in args['filter_sizes'].split(',')]
        model = ConvNet(vocab_size=args['max_vocab_size'] + 2,
                        embedding_dim=args['embed_dim'],
                        n_filters=args['num_filters'],
                        filter_sizes=sizes,
                        output_dim=2,
                        dropout=args['dropout'],
                        pad_idx=pad_idx,
                        embed_trainable=args['embed_trainable'],
                        batch_norm=args['batch_norm'])

    if args['net_type'] == 'attn':
        model = AttnNet(vocab_size=args['max_vocab_size'] + 2,
                        embedding_dim=args['embed_dim'],
                        rnn_hidden_dim=args['rnn_hidden_dim'],
                        rnn_num_layers=args['rnn_num_layers'],
                        output_dim=2,
                        bidirection=args['bidirection'],
                        rnn_cell_type=args['rnn_cell_type'],
                        dropout=args['dropout'],
                        pad_idx=pad_idx,
                        embed_trainable=args['embed_trainable'],
                        batch_norm=args['batch_norm'],
                        output_attn=False)

    # Load checkpoint weights; clean up the temp file even on failure.
    pth_url = '/'.join([base_git, pth_path])
    pth_local = wget.download(pth_url)
    try:
        checkpoint = torch.load(pth_local, map_location=device)
    finally:
        os.remove(pth_local)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    # BUG FIX: the original called model.cpu(), which leaves the model on
    # CPU while the input tensor below is moved to `device` — a device
    # mismatch for any non-CPU device. Keep model and input together.
    model.to(device)

    # Load pre-trained embeddings and zero the <unk>/<pad> rows, mirroring
    # the initialization used at training time.
    pretrained_embeddings = TEXT.vocab.vectors
    model.embedding.weight.data.copy_(pretrained_embeddings)
    model.embedding.weight.data[unk_idx] = torch.zeros(args['embed_dim'])
    model.embedding.weight.data[pad_idx] = torch.zeros(args['embed_dim'])

    # Tokenize and numericalize the document.
    tokens = [tok.text.lower() for tok in nlp.tokenizer(doc)]
    idx = [TEXT.vocab.stoi[t] for t in tokens]

    # Pad or truncate to exactly max_token_len. The original while-loop
    # appended a full max_token_len of the literal 1 per iteration and
    # relied on the truncation to recover; pad only the shortfall, and use
    # pad_idx rather than a hard-coded index.
    max_len = args['max_token_len']
    if len(idx) < max_len:
        idx = idx + [pad_idx] * (max_len - len(idx))
    else:
        idx = idx[:max_len]

    # Prediction — no gradients needed at inference time.
    model.eval()
    doc_tensor = torch.LongTensor(idx).to(device)
    doc_tensor = doc_tensor.unsqueeze(1)  # AttnNet expects [seq_len, batch_size]
    with torch.no_grad():
        probs = model(doc_tensor)
    probs = probs.data.cpu().numpy()[0]

    return probs[1]
# Example #2
        # Reset gradients accumulated by the previous iteration.
        optimizer.zero_grad()

        # NOTE(review): input_patch/gt_patch are bound by the (unseen) loop
        # above this chunk; assumed to be float tensors of matching shape —
        # confirm against the surrounding code.
        input_patch = input_patch.to(device)
        gt_patch = gt_patch.to(device)

        # forward + backward + optimize
        output = net(input_patch)
        loss = criterion(output, gt_patch)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report the mean loss once per epoch, on the final batch.
        if cnt == (num_ids - 1):
            print("epoch: %d, loss: %.4f" % (epoch, running_loss / num_ids))

        # Every save_freq epochs, write a ground-truth/prediction image
        # pair for visual inspection.
        if epoch % save_freq == 0:
            if not os.path.isdir(result_dir + '%04d' % epoch):
                os.makedirs(result_dir + '%04d' % epoch)

            # First sample of the batch: GT and prediction concatenated
            # along axis 1 into one array for side-by-side comparison.
            temp = np.concatenate((gt_patch.cpu().detach().numpy()[0, :, :, :],
                                   output.cpu().detach().numpy()[0, :, :, :]),
                                  axis=1)
            # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2+;
            # this requires an old SciPy (with Pillow) to run.
            scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0,
                               cmax=255).save(result_dir +
                                              '%04d/%05d_00_train_%d.jpg' %
                                              (epoch, train_id, ratio))

    # Persist the full model object (moved to CPU) after the loop finishes.
    torch.save(net.cpu(), checkpoint_dir + 'sid_torchversion_model.ckpt')