# Example No. 1
def train(save_dir='saved_weights', parser_name='parser', num_epochs=5,
          max_iters=-1, print_every_iters=10):
    """Trains the model.

    parser_name is the string prefix used for the filename where the parser is
    saved after every epoch.

    Args:
        save_dir: Directory where a checkpoint is written after every epoch;
            it is created if it does not exist.
        parser_name: Filename prefix for the per-epoch checkpoints
            (``<parser_name>-epoch-<epoch>.mdl`` inside ``save_dir``).
        num_epochs: Number of passes over the training data.
        max_iters: Debugging shortcut. If non-negative, each epoch stops once
            the minibatch index exceeds this value (so minibatches
            0..max_iters are used). A negative value uses all the data.
        print_every_iters: Print the current minibatch loss/accuracy whenever
            the minibatch index is a multiple of this value; a value <= 0
            disables this reporting.

    Returns:
        The trained parser model.

    NOTE(review): the original signature was lost when this snippet was
    extracted; the order and defaults of every parameter after ``save_dir``
    are reconstructed from how the body uses them -- confirm against callers.
    """
    # load dataset
    load_existing_dump = False
    print('Loading dataset for training')
    dataset = load_datasets(load_existing_dump)
    # HINT: Look in the ModelConfig class for the model's hyperparameters
    config = dataset.model_config

    print('Loading embeddings')
    word_embeddings, pos_embeddings, dep_embeddings = load_embeddings(config)
    # TODO: For Optional Task, add Twitter and Wikipedia embeddings (do this last)

    # Switch to True if you want to print examples of feature types
    show_feature_examples = False
    if show_feature_examples:
        print('words: ', len(dataset.word2idx))
        print('examples: ', [(k, v)
                             for i, (k, v) in enumerate(dataset.word2idx.items())
                             if i < 30])
        print('POS-tags: ', len(dataset.pos2idx))
        print('dependencies: ', len(dataset.dep2idx))
        print("some hyperparameters")

    # load parser object (used for Task 2)
    parser = ParserModel(config, word_embeddings, pos_embeddings, dep_embeddings)

    # Uncomment the following parser for Task 3
    # parser = AnotherParserModel(config, word_embeddings, pos_embeddings, dep_embeddings)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Every minibatch tensor below is sent to `device`, so the model's
    # parameters must live there too or the forward pass fails on a GPU.
    parser = parser.to(device)

    # set save_dir for model (no error if it already exists)
    os.makedirs(save_dir, exist_ok=True)

    # Cross-entropy applies the softmax itself, so the parser's raw
    # activations are passed to it directly.
    loss_fn = F.cross_entropy

    # Optimizer that updates the parser's weights, using the configured
    # learning rate.
    optimizer = optim.Adam(parser.parameters(), lr=config.lr)

    for epoch in range(1, num_epochs + 1):

        ###### Training #####

        # load training set in minibatches
        # NOTE(review): the arguments after the data list were lost in
        # extraction; the batch size and multi-feature flag here are assumed
        # -- confirm against get_minibatches' signature.
        minibatches = get_minibatches([dataset.train_inputs,
                                       dataset.train_targets],
                                      config.batch_size,
                                      is_multi_feature_input=True)
        for i, (train_x, train_y) in enumerate(minibatches):

            # Quick hack to cut training short to see how things are working.
            # Checked before any tensor conversion so no work is wasted on the
            # batch that triggers the stop. In the final model, use all the
            # data (negative max_iters).
            if max_iters >= 0 and i > max_iters:
                break

            word_inputs_batch, pos_inputs_batch, dep_inputs_batch = train_x

            # Convert the numpy data to pytorch's tensor representation and
            # send it to the same device as the parser.
            word_inputs_batch = torch.from_numpy(np.array(word_inputs_batch)).to(device)
            pos_inputs_batch = torch.from_numpy(np.array(pos_inputs_batch)).to(device)
            dep_inputs_batch = torch.from_numpy(np.array(dep_inputs_batch)).to(device)

            # Convert the labels from 1-hot vectors to the index of the 1,
            # which is what cross_entropy expects, then to a tensor.
            labels = torch.from_numpy(np.argmax(train_y, axis=1)).to(device)

            # Some debugging information
            if i == 0 and epoch == 1:
                print("size of word inputs: ", word_inputs_batch.size())
                print("size of pos inputs: ", pos_inputs_batch.size())
                print("size of dep inputs: ", dep_inputs_batch.size())
                print("size of labels: ", labels.size())

            #### Backprop & Update weights ####

            # Zero the gradients accumulated by the previous step.
            optimizer.zero_grad()

            # Full forward pass; the outputs are the raw activations over the
            # valid transitions. Calling the module (rather than .forward)
            # also runs any registered hooks.
            outputs = parser(word_inputs_batch, pos_inputs_batch, dep_inputs_batch)

            loss = loss_fn(outputs, labels)

            # Backward pass: compute gradient of the loss with respect to
            # the model parameters.
            loss.backward()

            # Perform 1 update using the optimizer
            optimizer.step()

            # Periodic reporting so convergence can be monitored.
            if print_every_iters > 0 and i % print_every_iters == 0:
                # Fraction of this minibatch whose highest-scoring transition
                # matches the gold label.
                acc = (outputs.argmax(dim=1) == labels).float().mean().item()
                print('Epoch: %d [%d], loss: %1.3f, acc: %1.3f'
                      % (epoch, i, loss.item(), acc))

        print("End of epoch")

        # save model
        save_file = os.path.join(save_dir, '%s-epoch-%d.mdl' % (parser_name, epoch))
        print('Saving current state of model to %s' % save_file)
        torch.save(parser, save_file)

        ###### Validation #####
        print('Evaluating on validation data after epoch %d' % epoch)

        # Evaluation mode turns off things like Dropout so weights are not
        # randomly zeroed out while scoring.
        parser.eval()

        # Compute the current model's UAS on the validation (development)
        # data. Held-out data may be used to tune hyper-parameters; the test
        # data should only be used to report the final result. Gradients are
        # not needed here, so autograd is disabled to save memory.
        with torch.no_grad():
            compute_dependencies(parser, device, dataset.valid_data, dataset)
        valid_UAS = get_UAS(dataset.valid_data)
        print("- validation UAS: {:.2f}".format(valid_UAS * 100.0))

        # Back to "train" mode, which turns things like Dropout back on.
        parser.train()

    return parser