Esempio n. 1
0
# ---------------------------------------------------------------------------
# Compute contexts: one MXNet CPU context for each index in range(cpu_count).
# ---------------------------------------------------------------------------
ctx = [mx.cpu(device_id) for device_id in range(cpu_count)]

# Datasets built with the "form_bb" parse method, "bb" output data and the
# detection_box output parser; the `train` flag selects the split.
# NOTE(review): presumably full-form images labelled with bounding boxes --
# confirm against the IAMDataset implementation.
train_ds = IAMDataset(
    "form_bb",
    train=True,
    output_data="bb",
    output_parse_method=detection_box,
)
print("Number of training samples: {}".format(len(train_ds)))

test_ds = IAMDataset(
    "form_bb",
    train=False,
    output_data="bb",
    output_parse_method=detection_box,
)
print("Number of testing samples: {}".format(len(test_ds)))

# Training loader: samples pass through augment_transform and are shuffled.
# Per the Gluon DataLoader docs, last_batch="rollover" carries the leftover
# samples of an incomplete final batch over to the next epoch.
train_data = gluon.data.DataLoader(
    train_ds.transform(augment_transform),
    batch_size=batch_size,
    shuffle=True,
    last_batch="rollover",
    num_workers=8,
)

# Evaluation loader: samples pass through the plain `transform`, keep their
# dataset order, and last_batch="keep" returns the smaller final batch so no
# sample is dropped.
test_data = gluon.data.DataLoader(
    test_ds.transform(transform),
    batch_size=batch_size,
    shuffle=False,
    last_batch="keep",
    num_workers=8,
)

#%%

# Locations used for logging and for the saved model parameters.
log_dir = "./logs/handwriting_recognition"
checkpoint_dir = "model_checkpoint"
checkpoint_name = "handwriting.params"

# SSD network with two classes, placed on the contexts in `ctx`.
# NOTE(review): the meaning of the two classes is not visible here --
# confirm against the SSD definition.
net = SSD(num_classes=2, ctx=ctx)
net.hybridize()

#%%


# Datasets built with the "line" parse method and "text" output data; the
# `train` flag selects the split.
# NOTE(review): presumably single text-line images paired with their
# transcriptions -- confirm against the IAMDataset implementation.
train_ds = IAMDataset(
    "line",
    train=True,
    output_data="text",
)
print("Number of training samples: {}".format(len(train_ds)))

test_ds = IAMDataset(
    "line",
    train=False,
    output_data="text",
)
print("Number of testing samples: {}".format(len(test_ds)))

#%%

# Training loader: augmented samples, shuffled; an incomplete final batch is
# rolled over to the next epoch (last_batch="rollover").
train_data = gluon.data.DataLoader(
    train_ds.transform(augment_transform),
    batch_size=batch_size,
    shuffle=True,
    last_batch="rollover",
    num_workers=4,
)

# Evaluation loader: plain `transform`; the smaller final batch is kept.
# NOTE(review): this loader shuffles the test split (shuffle=True) -- confirm
# that a non-deterministic evaluation order is intended.
test_data = gluon.data.DataLoader(
    test_ds.transform(transform),
    batch_size=batch_size,
    shuffle=True,
    last_batch="keep",
    num_workers=4,
)

#%%


# Recognition network, configured from the hyper-parameters defined elsewhere
# in this script and placed on the contexts in `ctx`.
net = CNNBiLSTM(
    num_downsamples=num_downsamples,
    resnet_layer_id=resnet_layer_id,
    rnn_hidden_states=lstm_hidden_states,
    rnn_layers=lstm_layers,
    max_seq_len=max_seq_len,
    ctx=ctx,
)
net.hybridize()

# CTC loss; `weight` is the global scalar applied to the loss value.
ctc_loss = gluon.loss.CTCLoss(weight=0.2)

# Large initial value (1,000,000.0, the same float as 10e5).
# NOTE(review): presumably the running minimum of the test loss, so that the
# first measured loss always improves on it -- confirm against the training
# loop.
best_test_loss = 1e6

# If a saved parameter file exists, load it into the network and print the
# result of one non-training pass over the test loader.
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
if os.path.isfile(checkpoint_path):
    net.load_parameters(checkpoint_path)
    print("Parameters loaded")
    # The fourth positional argument is None.
    # NOTE(review): presumably the trainer, unused when is_train=False --
    # confirm against run_epoch.
    print(
        run_epoch(
            0,
            net,
            test_data,
            None,
            log_dir,
            print_name="pretrained",
            is_train=False,
        )
    )