Example No. 1
def load_vanilla(num_classes, device, args):
    encoder = Encoder(use_stn=args.use_stn).to(device)
    decoder = AttentionDecoder(
        hidden_dim=256,
        attention_dim=256,
        y_dim=num_classes,
        encoder_output_dim=512,
        f_lookup_ts="/home/dataset/TR/synth_cn/lookup.pt").to(
            device)  # y_dim is the number of classes
    decorate_model(encoder, is_training=False, device=device)
    decorate_model(decoder, is_training=False, device=device)
    checkpoint = torch.load(args.pre_ocr)
    encoder.load_state_dict(checkpoint["state_dict"]["encoder"])
    decoder.load_state_dict(checkpoint["state_dict"]["decoder"], strict=False)
    return encoder, decoder
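
The strict=False on the decoder load skips mismatched keys instead of raising a RuntimeError, which helps when a checkpoint and the current class definition have drifted apart. A minimal self-contained sketch of that behavior (hypothetical modules, not the classes above):

import torch.nn as nn

plain = nn.Linear(4, 4)
wrapped = nn.Sequential(nn.Linear(4, 4))  # same layer, but keys become "0.weight"/"0.bias"

# strict=False returns the mismatches instead of raising.
result = wrapped.load_state_dict(plain.state_dict(), strict=False)
print(result.missing_keys)     # ['0.weight', '0.bias']
print(result.unexpected_keys)  # ['weight', 'bias']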
Example No. 2
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V, D, H = vocab_size, wordvec_size, hidden_size
        self.encoder = Encoder(V, D, H)
        self.decoder = PeekyDecoder(V, D, H)
        self.softmax = TimeSoftmaxWithLoss()

        self.params = self.encoder.params + self.decoder.params
        self.grads = self.encoder.grads + self.decoder.grads
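
The concatenated params/grads lists follow the scratch-framework convention of exposing all layer parameters as flat, parallel lists so one optimizer loop can update everything. A hedged sketch of the update step this enables, assuming the entries are NumPy arrays:

def sgd_step(params, grads, lr=0.01):
    # In-place update, so the layers that hold these arrays see the new values.
    for p, g in zip(params, grads):
        p -= lr * g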
Example No. 3
    def __init__(self,
                 input: list,
                 hidden_sz: int,
                 output: list,
                 embed_sz: int = 4,
                 num_layers: int = 4):
        super(AutoEncoderHandler, self).__init__()
        self.input = input
        self.output = output
        self.input_sz = len(input)
        self.output_sz = len(output)
        self.hidden_sz = hidden_sz
        self.embed_sz = embed_sz
        self.num_layers = num_layers
        self.encoder = Encoder(self.input_sz, self.hidden_sz, self.embed_sz,
                               num_layers)
        self.decoder = Decoder(self.output_sz, self.hidden_sz, self.output_sz,
                               self.embed_sz, num_layers)
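
A hedged instantiation sketch, assuming input and output are token vocabularies (the SOS/EOS markers are assumptions used later by encode_decode in Example No. 5):

SOS, EOS = "<s>", "</s>"
src_vocab = [SOS, EOS] + list("abcdef")
tgt_vocab = [SOS, EOS] + list("uvwxyz")
handler = AutoEncoderHandler(src_vocab, hidden_sz=32, output=tgt_vocab)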
Example No. 4
    def load_model(self, weights, device):
        INPUT_DIM = len(self.SRC.vocab)
        OUTPUT_DIM = len(self.TRG.vocab)
        enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM,
                      ENC_DROPOUT, device)
        dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM,
                      DEC_DROPOUT, device)
        SRC_PAD_IDX = self.SRC.vocab.stoi[self.SRC.pad_token]
        TRG_PAD_IDX = self.TRG.vocab.stoi[self.TRG.pad_token]
        model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
        model.load_state_dict(torch.load(weights))
        return model
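
torch.load(weights) restores tensors onto the device they were saved from, so a GPU-saved checkpoint fails to load on CPU-only hardware. A hedged variant of the last two lines that pins tensors to the target device:

model.load_state_dict(torch.load(weights, map_location=device))
model.eval()  # assumption: the loaded model is used for inference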
Example No. 5
class AutoEncoderHandler():
    def __init__(self,
                 input: list,
                 hidden_sz: int,
                 output: list,
                 embed_sz: int = 4,
                 num_layers: int = 4):
        super(AutoEncoderHandler, self).__init__()
        self.input = input
        self.output = output
        self.input_sz = len(input)
        self.output_sz = len(output)
        self.hidden_sz = hidden_sz
        self.embed_sz = embed_sz
        self.num_layers = num_layers
        self.encoder = Encoder(self.input_sz, self.hidden_sz, self.embed_sz,
                               num_layers)
        self.decoder = Decoder(self.output_sz, self.hidden_sz, self.output_sz,
                               self.embed_sz, num_layers)

    def encode_decode(self,
                      input: torch.Tensor,
                      address: str,
                      break_on_eos: bool = True):
        hidden = None
        input_len = len(input)
        for i in range(len(input)):
            _, hidden = self.encoder.forward(input[i].unsqueeze(0), hidden)

        input = torch.LongTensor([self.output.index(SOS)]).to(DEVICE)
        sample = self.output.index(SOS)
        ret = []
        samples = []

        if break_on_eos:
            t = 0
            while sample != self.output.index(EOS):
                output, hidden = self.decoder.forward(input, hidden)
                # Each step needs its own pyro address, hence the counter
                # (the original reused the stale encoder index i here).
                sample = pyro.sample(f"{address}_{t}",
                                     dist.Categorical(output)).item()
                ret.append(self.output[sample])
                samples.append(sample)
                # Wrap in a list: LongTensor(int) allocates a tensor of that size.
                input = torch.LongTensor([sample]).to(DEVICE)
                t += 1
        else:
            for i in range(input_len):
                output, hidden = self.decoder.forward(input, hidden)
                sample = pyro.sample(f"{address}_{i}",
                                     dist.Categorical(output)).item()
                ret.append(self.output[sample])
                samples.append(sample)
                input = torch.LongTensor([sample]).to(DEVICE)

        return samples, ret
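
Every pyro.sample call inside a single trace needs a unique address, which is why the loop index goes into the f"{address}_..." name. A minimal self-contained sketch of that pattern:

import torch
import pyro
import pyro.distributions as dist

probs = torch.tensor([0.1, 0.2, 0.7])
draws = []
for t in range(5):
    # Fresh address per step; reusing a name within one trace raises an error.
    draws.append(pyro.sample(f"step_{t}", dist.Categorical(probs)).item())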
Example No. 6
def load_vanilla(num_classes, device, args):
    encoder = Encoder(use_stn=args.use_stn).to(device)
    decoder = AttentionDecoder(
        hidden_dim=256,
        attention_dim=256,
        y_dim=num_classes,
        encoder_output_dim=512,
        f_lookup_ts="/home/dataset/TR/synth_cn/lookup.pt").to(
            device)  # y_dim is the number of classes
    lm = AttentionNet2(input_size=200,
                       hidden_size=512,
                       depth=3,
                       head=5,
                       num_classes=num_classes - 2,
                       k=8)
    decorate_model(encoder, is_training=False, device=device)
    decorate_model(decoder, is_training=False, device=device)
    decorate_model(lm, is_training=False, device=device)
    checkpoint = torch.load(args.ocr)
    encoder.load_state_dict(checkpoint["state_dict"]["encoder"])
    decoder.load_state_dict(checkpoint["state_dict"]["decoder"])
    lm.load_state_dict(torch.load(args.nlp))
    return encoder, decoder, lm
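
decorate_model appears in several examples on this page but its definition is not shown. A plausible minimal stand-in, consistent with how it is called here (an assumption, not the repo's implementation):

def decorate_model(model, is_training=False, device="cuda"):
    # Hypothetical helper: move the model to the device and set train/eval mode.
    model.to(device)
    model.train(is_training)
    return model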
Example No. 7
    def __init__(self,
                 num_layers: int = 2,
                 hidden_sz: int = 64,
                 peak_prob: float = 0.9):
        super().__init__()
        # Model neural nets instantiation
        self.model_fn_lstm = Decoder(LETTERS_COUNT,
                                     hidden_sz,
                                     LETTERS_COUNT,
                                     num_layers=num_layers)
        # Guide neural nets instantiation
        self.guide_fn_lstm = Decoder(LETTERS_COUNT,
                                     hidden_sz,
                                     LETTERS_COUNT,
                                     num_layers=num_layers)
        # Instantiate encoder
        self.encoder_lstm = Encoder(PRINTABLES_COUNT,
                                    hidden_sz,
                                    num_layers=num_layers)

        # Hyperparameters
        self.peak_prob = peak_prob
        self.num_layers = num_layers
        self.hidden_sz = hidden_sz
Example No. 8
def create_model(vocab_size):
    embedding = (nn.Embedding(vocab_size, config.hidden_size)
                 if config.single_embedding else None)

    encoder = Encoder(vocab_size, config.hidden_size,
                      n_layers=config.n_encoder_layers, dropout=config.dropout)

    decoder = Decoder(config.hidden_size, vocab_size,
                      n_layers=config.n_decoder_layers, dropout=config.dropout)

    model = Seq2Seq(encoder=encoder,
                    decoder=decoder,
                    max_length=config.max_length)

    if torch.cuda.is_available() and config.use_cuda:
        model.cuda()

    return model
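
create_model reads everything from a module-level config object. A minimal stand-in with the fields the function expects (all values are placeholders):

from types import SimpleNamespace

config = SimpleNamespace(hidden_size=256,
                         single_embedding=False,
                         n_encoder_layers=2,
                         n_decoder_layers=2,
                         dropout=0.1,
                         max_length=50,
                         use_cuda=True)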
Example No. 9
    def __init__(self, **kwargs):
        super(Seq2Seq, self).__init__()

        # Define the hyper-parameters or arguments
        self.batch_size = kwargs['batch_size']
        self.enc_max_len = kwargs['enc_max_len']
        self.dec_max_len = kwargs['dec_max_len']
        self.enc_unit = kwargs['enc_unit']
        self.dec_unit = kwargs['dec_unit']
        self.embed_dim = kwargs['embed_dim']
        self.dropout_rate = kwargs['dropout_rate']
        self.enc_vocab_size = kwargs['enc_vocab_size']
        self.dec_vocab_size = kwargs['dec_vocab_size']
        self.sos_token = kwargs['dec_sos_token']

        # Define the encoder and decoder layers
        self.encoder = Encoder(self.batch_size, self.enc_max_len, self.enc_unit,
                               self.dropout_rate, self.enc_vocab_size, self.embed_dim)
        self.decoder = Decoder(self.batch_size, self.dec_max_len, self.dec_unit,
                               self.embed_dim, self.dec_vocab_size, self.dropout_rate)
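
A hedged instantiation sketch showing the keyword arguments this constructor expects (the hyperparameter values are placeholders):

model = Seq2Seq(batch_size=32, enc_max_len=40, dec_max_len=40,
                enc_unit=256, dec_unit=256, embed_dim=128,
                dropout_rate=0.1, enc_vocab_size=8000,
                dec_vocab_size=8000, dec_sos_token=1)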
Example No. 10
                                                        batch_size=batch_size,
                                                        device=device)

INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
HID_DIM = config['HID_DIM']
ENC_LAYERS = config['ENC_LAYERS']
DEC_LAYERS = config['DEC_LAYERS']
ENC_HEADS = config['ENC_HEADS']
DEC_HEADS = config['DEC_HEADS']
ENC_PF_DIM = config['ENC_PF_DIM']
DEC_PF_DIM = config['DEC_PF_DIM']
ENC_DROPOUT = config['ENC_DROPOUT']
DEC_DROPOUT = config['DEC_DROPOUT']

enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM,
              ENC_DROPOUT, device)

dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM,
              DEC_DROPOUT, device)

SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]

model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

print(f'The model has {count_parameters(model):,} trainable parameters')

model.apply(initialize_weights)

if config['train_embeddings']:
    model.decoder.tok1_embedding.load_state_dict(glovemodel.wi.state_dict())
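
model.apply(initialize_weights) walks every submodule, but initialize_weights itself is not shown on this page. A common definition for transformer seq2seq models, offered as an assumption rather than the original:

import torch.nn as nn

def initialize_weights(m):
    # Xavier-init every weight matrix with more than one dimension (an assumed, typical choice).
    if hasattr(m, 'weight') and m.weight is not None and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)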
Example No. 11
class NameParser():
    """
    Generates names using a separate LSTM for first, middle, last name and a neural net
    using ELBO to parameterize NN for format classification.

    input_size: Should be the number of letters to allow
    hidden_size: Size of the hidden dimension in LSTM
    num_layers: Number of hidden layers in LSTM
    hidden_sz: Hidden layer size for LSTM RNN
    peak_prob: The max expected probability
    """
    def __init__(self,
                 num_layers: int = 2,
                 hidden_sz: int = 64,
                 peak_prob: float = 0.9):
        super().__init__()
        # Model neural nets instantiation
        self.model_fn_lstm = Decoder(LETTERS_COUNT,
                                     hidden_sz,
                                     LETTERS_COUNT,
                                     num_layers=num_layers)
        # Guide neural nets instantiation
        self.guide_fn_lstm = Decoder(LETTERS_COUNT,
                                     hidden_sz,
                                     LETTERS_COUNT,
                                     num_layers=num_layers)
        # Instantiate encoder
        self.encoder_lstm = Encoder(PRINTABLES_COUNT,
                                    hidden_sz,
                                    num_layers=num_layers)

        # Hyperparameters
        self.peak_prob = peak_prob
        self.num_layers = num_layers
        self.hidden_sz = hidden_sz

    def model(self, X_u: list, X_s: list, Z_s: dict, observations=None):
        """
        Model for generating names representing p(x,z)
        x: Training data (name string)
        z: Optionally supervised latent values (dictionary of name/format values)
        """
        pyro.module("model_fn_lstm", self.model_fn_lstm)

        formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH,
                                          printable_to_index)
        formatted_X_s = strings_to_tensor(X_s, MAX_NAME_LENGTH,
                                          printable_to_index)

        with pyro.plate("sup_batch", len(X_s)):
            _, first_names = self.generate_name_supervised(
                self.model_fn_lstm,
                FIRST_NAME_ADD,
                len(X_s),
                observed=Z_s[FIRST_NAME_ADD])
            full_names = list(
                map(lambda name: pad_string(name, MAX_NAME_LENGTH),
                    first_names))
            probs = strings_to_probs(full_names,
                                     MAX_NAME_LENGTH,
                                     printable_to_index,
                                     true_index_prob=self.peak_prob)
            pyro.sample("sup_output",
                        dist.OneHotCategorical(probs.transpose(0,
                                                               1)).to_event(1),
                        obs=formatted_X_s.transpose(0, 1))

        with pyro.plate("unsup_batch", len(X_u)):
            _, first_names = self.generate_name(self.model_fn_lstm,
                                                FIRST_NAME_ADD, len(X_u))
            full_names = list(
                map(lambda name: pad_string(name, MAX_NAME_LENGTH),
                    first_names))
            probs = strings_to_probs(full_names,
                                     MAX_NAME_LENGTH,
                                     printable_to_index,
                                     true_index_prob=self.peak_prob)
            pyro.sample("unsup_output",
                        dist.OneHotCategorical(probs.transpose(0,
                                                               1)).to_event(1),
                        obs=formatted_X_u.transpose(0, 1))

        return full_names

    def guide(self, X_u: list, X_s: list, Z_s: dict, observations=None):
        """
        Guide for approximation of the posterior q(z|x)
        x: Training data (name string)
        z: Optionally supervised latent values (dictionary of name/format values)
        """

        pyro.module("guide_fn_lstm", self.guide_fn_lstm)
        pyro.module("encoder_lstm", self.encoder_lstm)

        if observations is None:
            formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH,
                                              printable_to_index)
        else:
            formatted_X_u = observations['unsup_output'].transpose(0, 1)

        hidd_cell_states = self.encoder_lstm.init_hidden(batch_size=len(X_u))
        for i in range(formatted_X_u.shape[0]):
            _, hidd_cell_states = self.encoder_lstm.forward(
                formatted_X_u[i].unsqueeze(0), hidd_cell_states)

        with pyro.plate("unsup_batch", len(X_u)):
            _, first_names = self.generate_name(
                self.guide_fn_lstm,
                FIRST_NAME_ADD,
                len(X_u),
                hidd_cell_states=hidd_cell_states,
                sample=False)

        return first_names

    def infer(self, X_u: list):
        formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH,
                                          printable_to_index)
        hidd_cell_states = self.encoder_lstm.init_hidden(batch_size=len(X_u))
        for i in range(formatted_X_u.shape[0]):
            _, hidd_cell_states = self.encoder_lstm.forward(
                formatted_X_u[i].unsqueeze(0), hidd_cell_states)
        _, first_names = self.generate_name(self.guide_fn_lstm,
                                            FIRST_NAME_ADD,
                                            len(X_u),
                                            hidd_cell_states=hidd_cell_states,
                                            sample=False)
        return first_names

    def generate(self):
        _, name = self.generate_name(self.model_fn_lstm, FIRST_NAME_ADD, 1)
        return name

    def generate_name(self,
                      lstm: Decoder,
                      address: str,
                      batch_size: int,
                      hidd_cell_states: tuple = None,
                      sample: bool = True):
        """
        lstm: Decoder associated with name being generated
        address: The address to correlate pyro distribution with latent variables
        hidd_cell_states: Previous LSTM hidden state or empty hidden state
        max_name_length: The max name length allowed
        """
        # If no hidden state is provided, initialize it with all 0s
        if hidd_cell_states is None:
            hidd_cell_states = lstm.init_hidden(batch_size=batch_size)

        input_tensor = strings_to_tensor([SOS] * batch_size, 1,
                                         letter_to_index)
        names = [''] * batch_size

        for index in range(MAX_NAME_LENGTH):
            char_dist, hidd_cell_states = lstm.forward(input_tensor,
                                                       hidd_cell_states)

            if sample:
                # Next LSTM input is the sampled character
                input_tensor = pyro.sample(f"unsup_{address}_{index}",
                                           dist.OneHotCategorical(char_dist))
                chars_at_indexes = list(
                    map(lambda index: MODEL_CHARS[int(index.item())],
                        torch.argmax(input_tensor, dim=2).squeeze(0)))
            else:
                # Next LSTM input is the character with the highest probability of occurring
                pyro.sample(f"unsup_{address}_{index}",
                            dist.OneHotCategorical(char_dist))
                chars_at_indexes = list(
                    map(lambda index: MODEL_CHARS[int(index.item())],
                        torch.argmax(char_dist, dim=2).squeeze(0)))
                input_tensor = strings_to_tensor(chars_at_indexes, 1,
                                                 letter_to_index)

            # Add sampled characters to names
            for i, char in enumerate(chars_at_indexes):
                names[i] += char

        # Discard everything after EOS character
        # names = list(map(lambda name: name[:name.find(EOS)] if name.find(EOS) > -1 else name, names))
        return hidd_cell_states, names

    def generate_name_supervised(self,
                                 lstm: Decoder,
                                 address: str,
                                 batch_size: int,
                                 observed: list = None):
        """
        lstm: Decoder associated with name being generated
        address: The address to correlate pyro distribution with latent variables
        observed: List of observed name strings (one per batch element)
        """
        hidd_cell_states = lstm.init_hidden(batch_size=batch_size)
        observed_tensor = strings_to_tensor(observed, MAX_NAME_LENGTH,
                                            letter_to_index)
        input_tensor = strings_to_tensor([SOS] * batch_size, 1,
                                         letter_to_index)
        names = [''] * batch_size

        for index in range(MAX_NAME_LENGTH):
            char_dist, hidd_cell_states = lstm.forward(input_tensor,
                                                       hidd_cell_states)
            input_tensor = pyro.sample(f"sup_{address}_{index}",
                                       dist.OneHotCategorical(char_dist),
                                       obs=observed_tensor[index].unsqueeze(0))
            # Sampled char should be an index not a one-hot
            chars_at_indexes = list(
                map(lambda index: MODEL_CHARS[int(index.item())],
                    torch.argmax(input_tensor, dim=2).squeeze(0)))
            # Add sampled characters to names
            for i, char in enumerate(chars_at_indexes):
                names[i] += char

        # Discard everything after EOS character
        names = list(
            map(
                lambda name: name[:name.find(EOS)]
                if name.find(EOS) > -1 else name, names))
        return hidd_cell_states, names

    def load_checkpoint(self,
                        folder="nn_model",
                        filename="checkpoint.pth.tar"):
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            raise Exception(f"No model in path {folder}")
        save_content = torch.load(filepath, map_location=DEVICE)
        self.model_fn_lstm.load_state_dict(save_content['model_fn_lstm'])
        self.guide_fn_lstm.load_state_dict(save_content['guide_fn_lstm'])
        self.encoder_lstm.load_state_dict(save_content['encoder_lstm'])

    def save_checkpoint(self,
                        folder="nn_model",
                        filename="checkpoint.pth.tar"):
        filepath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            os.mkdir(folder)
        save_content = {
            'model_fn_lstm': self.model_fn_lstm.state_dict(),
            'guide_fn_lstm': self.guide_fn_lstm.state_dict(),
            'encoder_lstm': self.encoder_lstm.state_dict()
        }
        torch.save(save_content, filepath)
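
A hedged sketch of how a model/guide pair like this is typically trained with Pyro's SVI (the data values are placeholders; FIRST_NAME_ADD comes from the example's globals):

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

name_parser = NameParser(num_layers=2, hidden_sz=64)
svi = SVI(name_parser.model, name_parser.guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())

X_u = ["john smith"]               # unsupervised names
X_s = ["jane doe"]                 # supervised names
Z_s = {FIRST_NAME_ADD: ["jane"]}   # supervised latent values, keyed by address
loss = svi.step(X_u, X_s, Z_s)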
Example No. 12
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False,
                                          collate_fn=collate_fn,
                                          drop_last=True)

INPUT_DIM = len(src_vocab)
OUTPUT_DIM = len(tag_vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
model = Seq2Seq(enc, dec)


# init weights
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)


model.apply(init_weights)


# calculate the number of trainable parameters in the model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
Example No. 13
def main(args):
    input_size = [64, 256] if args.use_stn else [32, 128]
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_trsf = transforms.Compose([transforms.ToTensor(), normalize])
    train_loader = get_dataloader(args.training_data, alphabet, input_size,
                                  train_trsf, args.batch_size)
    test_trsf = transforms.Compose([transforms.ToTensor(), normalize])
    test_loader = get_dataloader(args.test_data,
                                 alphabet,
                                 input_size,
                                 test_trsf,
                                 args.batch_size,
                                 is_train=False)

    encoder = Encoder(use_stn=args.use_stn).cuda()
    decoder = AttentionDecoder(
        hidden_dim=256,
        attention_dim=256,
        y_dim=converter.num_classes,
        encoder_output_dim=512,
        f_lookup_ts="/home/dataset/TR/synth_cn/lookup.pt").cuda(
        )  # y_dim is the number of classes
    encoder_optimizer = optim.Adadelta(encoder.parameters(), lr=args.lr)
    decoder_optimizer = optim.Adadelta(decoder.parameters(), lr=args.lr)
    optimizers = [encoder_optimizer, decoder_optimizer]
    lr_step = [100000, 300000, 500000]
    # lr_step = [100000, 200000]
    encoder_scheduler = optim.lr_scheduler.MultiStepLR(encoder_optimizer,
                                                       lr_step,
                                                       gamma=0.1)
    decoder_scheduler = optim.lr_scheduler.MultiStepLR(decoder_optimizer,
                                                       lr_step,
                                                       gamma=0.1)
    criterion = SequenceCrossEntropyLoss()

    step, total_loss, best_res = 1, 0, 0
    # For fine-tuning
    # checkpoint = torch.load('./logs/model-120000.pth')
    # encoder.load_state_dict(checkpoint["state_dict"]["encoder"])
    # decoder.load_state_dict(checkpoint["state_dict"]["decoder"], strict=False)

    if args.restore_step > 0:
        step = args.restore_step
        load_state(args.logs_dir, step, encoder, decoder, optimizers)

    sys.stdout = Logger(os.path.join(args.logs_dir, 'log.txt'))
    train_tfLogger = SummaryWriter(os.path.join(args.logs_dir, 'train'))
    test_tfLogger = SummaryWriter(os.path.join(args.logs_dir, 'test'))

    # start training
    while True:
        for batch_idx, (imgs, (targets, targets_len),
                        idx) in enumerate(train_loader):

            input_data, targets, targets_len = (imgs.cuda(), targets.cuda(),
                                                targets_len.cuda())
            encoder_optimizer.zero_grad(), decoder_optimizer.zero_grad()

            loss, recitified_img = batch_train(input_data, targets,
                                               targets_len, encoder, decoder,
                                               criterion, 1.0)
            encoder_optimizer.step(), decoder_optimizer.step()
            encoder_scheduler.step(), decoder_scheduler.step()
            total_loss += loss

            if step % 500 == 0:
                print('==' * 30)
                preds, _ = batch_test(input_data, encoder, decoder)
                print('preds: ', converter.decode(preds.cpu().numpy()))
                print('==' * 30)
                print('label: ',
                      converter.decode(targets.permute(1, 0).cpu().numpy()))
                encoder.train(), decoder.train()

            if step % args.log_interval == 0:
                print('{} step:{}\tLoss: {:.6f}'.format(
                    datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S'),
                    step, total_loss / args.log_interval))
                if train_tfLogger is not None:
                    """
                    x = vutils.make_grid(input_data.cpu())
                    train_tfLogger.add_image('train/input_img', x, step)
                    if args.use_stn:
                        x = vutils.make_grid(recitified_img.cpu())
                        train_tfLogger.add_image('train/recitified_img', x, step)
                    """
                    for param_group in encoder_optimizer.param_groups:
                        lr = param_group['lr']
                    info = {
                        'loss': total_loss / args.log_interval,
                        'learning_rate': lr
                    }
                    for tag, value in info.items():
                        train_tfLogger.add_scalar(tag, value, step)
                total_loss = 0
            if step % args.save_interval == 0:
                # save params
                save_state(args.logs_dir, step, encoder, decoder, optimizers)

                # Test after an args.save_interval
                res = test(encoder,
                           decoder,
                           test_loader,
                           step=step,
                           tfLogger=test_tfLogger)
                is_best = res >= best_res
                best_res = max(res, best_res)
                print(
                    '\nFinished step {:3d}  TestAcc: {:.4f}  best: {:.2%}{}\n'.
                    format(step, res, best_res, ' *' if is_best else ''))
                encoder.train(), decoder.train()

            step += 1

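    # NOTE: unreachable as written; the while True loop above never breaks.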
    # Close the tf logger
    train_tfLogger.close()
    test_tfLogger.close()
Example No. 14
    parser.add_argument('--batch',
                        type=int,
                        default=8000,
                        help='the start id of a batch')
    parser.add_argument('--use_stn', action='store_true', default=False)
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.idc
    device = torch.device('cuda:0')

    f_alphabet = "/home/dataset/TR/synth_cn/alphabet.json"
    alphabet = get_alphabet(f_alphabet)
    label_map = Ids2Str(alphabet)

    # get model
    encoder = Encoder(use_stn=args.use_stn).cuda()
    decoder = AttentionDecoder(
        hidden_dim=256,
        attention_dim=256,
        y_dim=label_map.num_classes,
        encoder_output_dim=512,
        f_lookup_ts="/home/dataset/TR/synth_cn/lookup.pt").cuda(
        )  # y_dim is the number of classes

    decorate_model(encoder)
    decorate_model(decoder)
    checkpoint = torch.load(args.param)
    encoder.load_state_dict(checkpoint["state_dict"]["encoder"])
    decoder.load_state_dict(checkpoint["state_dict"]["decoder"], strict=False)

    if args.img: