Example #1
import numpy as np
import cntk as C
from cntk.train.distributed import Communicator

# `DataSource`, `opt` and `vocab_sqrt` come from the surrounding module:
# the minibatch reader, the parsed command-line options, and the row/column
# dimension of the 2D word table (roughly the square root of the vocabulary).
def calculate_loss_vector(network, path, location_path, communicator):
    source = DataSource(path, opt.vocab_file, location_path,
                        opt.seqlength, opt.batchsize)
    # The model emits two outputs:
    #   output 0: from the current row input, predict the current word's column
    #   output 1: from the current column input, predict the next word's row
    row_loss = C.log(C.softmax(network['model'].outputs[0]))
    col_loss = C.log(C.softmax(network['model'].outputs[1]))
    loss = C.combine([row_loss, col_loss])
    row_loss_vector = np.zeros((opt.vocabsize, vocab_sqrt))
    col_loss_vector = np.zeros((opt.vocabsize, vocab_sqrt))

    flag = True
    while flag:
        mb = source.next_minibatch(opt.seqlength * opt.batchsize * Communicator.num_workers(),
                                   Communicator.num_workers(),
                                   communicator.rank())
        result = loss.eval({
            network['row']: mb[source.input1],
            network['col']: mb[source.input2],
        })
        row_prob = result[loss.outputs[0]]
        col_prob = result[loss.outputs[1]]
        label1 = mb[source.word1].asarray()
        label2 = mb[source.word2].asarray()
        sequences = len(label1)
        for i in range(sequences):
            seqlength = len(row_prob[i])
            for j in range(seqlength):
                row_word = int(label1[i][j][0])
                col_word = int(label2[i][j][0])
                # accumulate the negative log-likelihood for each word id
                row_loss_vector[row_word] -= row_prob[i][j]
                col_loss_vector[col_word] -= col_prob[i][j]
        flag = not mb[source.input1].sweep_end
    # output 0 scores candidate columns and output 1 scores candidate rows,
    # so the pair is returned swapped: the caller unpacks it as
    # `row_loss, col_loss = calculate_loss_vector(...)`.
    return col_loss_vector, row_loss_vector
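
To make the accumulation concrete, here is a minimal standalone sketch of the inner loop on toy data. The shapes (vocabsize = 4, vocab_sqrt = 2), the fake probabilities and the word ids are invented for illustration; only the `vector[word] -= log_prob` pattern is taken from the function above.

import numpy as np

vocabsize, vocab_sqrt = 4, 2
row_loss_vector = np.zeros((vocabsize, vocab_sqrt))

# one sequence of three tokens; each token carries a log-probability
# vector over the vocab_sqrt candidate positions
log_probs = np.log(np.array([[0.9, 0.1],
                             [0.2, 0.8],
                             [0.5, 0.5]]))
word_ids = [2, 2, 0]   # word id of each token

for j, w in enumerate(word_ids):
    row_loss_vector[w] -= log_probs[j]   # accumulate negative log-likelihood

print(row_loss_vector)
# word 2 accumulates two tokens' worth of NLL, word 0 one, the rest stay zero
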
Example #2
from functools import reduce
from operator import add

def evaluate(network, path, location_path):
    criterion = create_criterion(network)
    ce = criterion[0]   # the cross-entropy node of the criterion
    source = DataSource(path, opt.vocab_file, location_path,
                        opt.seqlength, opt.batchsize)
    error, tokens = 0, 0
    flag = True
    while flag:
        mb = source.next_minibatch(opt.seqlength * opt.batchsize)
        loss = ce.eval({
            network['row']: mb[source.input1],
            network['col']: mb[source.input2],
            network['row_label']: mb[source.label1],
            network['col_label']: mb[source.label2]
        })
        # `loss` is a list of per-sequence arrays of per-token cross-entropy;
        # sum over each sequence, then over the minibatch
        error += sum(reduce(add, seq)[0] for seq in loss)
        tokens += mb[source.input1].num_samples
        flag = not mb[source.input1].sweep_end
    return error / tokens
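
Since `evaluate` returns the average per-token cross-entropy over the corpus, perplexity follows by exponentiation. A minimal usage sketch, assuming `network`, `valid_path` and `location_path` are set up as in the examples here:

import math

avg_ce = evaluate(network, valid_path, location_path)   # mean cross-entropy per token
print("valid perplexity: {:.2f}".format(math.exp(avg_ce)))
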
Example #3
import os
import datetime
import cntk as C
from cntk.train.distributed import Communicator, data_parallel_distributed_learner

def train(network, location_path, id):
    train_path = os.path.join(opt.datadir, opt.train_file)
    valid_path = os.path.join(opt.datadir, opt.valid_file)
    test_path = os.path.join(opt.datadir, opt.test_file)

    criterion = create_criterion(network)
    ce, pe = criterion[0], criterion[1]
    learner = create_learner(network['model'])
    learner = data_parallel_distributed_learner(learner)
    communicator = learner.communicator()
    trainer = C.Trainer(network['model'], (ce, pe), learner)

    # loop over epochs
    for epoch in range(opt.epochs[id]):
        source = DataSource(train_path, opt.vocab_file, location_path,
                            opt.seqlength, opt.batchsize)
        loss, metric, tokens, batch_id = 0, 0, 0, 0
        start_time = datetime.datetime.now()
        flag = True

        # loop over minibatches in the epoch
        while flag:
            mb = source.next_minibatch(opt.seqlength * opt.batchsize * Communicator.num_workers(),
                                       Communicator.num_workers(),
                                       communicator.rank())
            trainer.train_minibatch({
                network['row']: mb[source.input1],
                network['col']: mb[source.input2],
                network['row_label']: mb[source.label1],
                network['col_label']: mb[source.label2]
            })
            samples = trainer.previous_minibatch_sample_count
            loss += trainer.previous_minibatch_loss_average * samples
            metric += trainer.previous_minibatch_evaluation_average * samples
            tokens += samples
            batch_id += 1
            if Communicator.num_workers() > 1:
                communicator.barrier()
            if batch_id % opt.freq == 0:
                diff_time = (datetime.datetime.now() - start_time)
                print("Epoch {:2}: Minibatch [{:5} - {:5}], loss = {:.6f}, error = {:.6f}, speed = {:3} tokens/s".format(
                    epoch + 1, batch_id - opt.freq + 1, batch_id,
                    loss / tokens, metric / tokens, tokens // max(diff_time.seconds, 1)))
            flag = not mb[source.input1].sweep_end

        # Evaluation action
        if communicator.is_main():
            valid_error = evaluate(network, valid_path, location_path)
            test_error = evaluate(network, test_path, location_path)
            print("Epoch {:2} Done : Valid error = {:.6f}, Test error = {:.6f}".format(epoch + 1, valid_error, test_error))
            network['model'].save(os.path.join(opt.outputdir, 'round{}_epoch{}_'.format(id, epoch) + opt.save))
        if Communicator.num_workers() > 1:
            communicator.barrier()
    # word allocation step: rebuild the word table from the accumulated losses
    row_loss, col_loss = calculate_loss_vector(network, train_path, location_path, communicator)
    if Communicator.num_workers() > 1:
        try:
            from mpi4py import MPI
        except ImportError:
            raise RuntimeError("Please install mpi4py to run on multiple GPUs!")
        comm = MPI.COMM_WORLD
        if communicator.is_main():
            # gather the partial loss vectors from every other worker
            for i in range(1, Communicator.num_workers()):
                row_loss_i, col_loss_i = comm.recv(source=i)
                row_loss += row_loss_i
                col_loss += col_loss_i
        else:
            comm.send([row_loss, col_loss], dest=0)
        communicator.barrier()
    if communicator.is_main():
        allocate_table(row_loss, col_loss,
                       opt.vocabsize, vocab_sqrt, opt.vocab_file,
                       get_k_round_location_path(id + 1))
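
The explicit recv/send loop above can also be written as a single reduction. A hedged alternative sketch, not what the example itself does: mpi4py's lowercase (pickle-based) reduce applies MPI.SUM via the objects' own addition, so numpy arrays are summed elementwise onto rank 0. The toy shapes are invented for illustration.

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD

# each worker's partial loss vectors (toy shapes for illustration)
row_loss = np.ones((4, 2))
col_loss = np.ones((4, 2))

# rank 0 receives the elementwise totals; every other rank gets None
row_loss = comm.reduce(row_loss, op=MPI.SUM, root=0)
col_loss = comm.reduce(col_loss, op=MPI.SUM, root=0)

if comm.Get_rank() == 0:
    print(row_loss.sum())   # == 8 * comm.Get_size()
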