def calculate_loss_vector(network, path, location_path, communicator):
    """Accumulate, for every word, the log-loss of predicting each row and
    each column of the word table; these vectors drive the re-allocation."""
    source = DataSource(path, opt.vocab_file, location_path, opt.seqlength, opt.batchsize)
    # The model factorizes prediction in two steps:
    #   the current row    -> the current column
    #   the current column -> the next row
    row_loss = C.log(C.softmax(network['model'].outputs[0]))
    col_loss = C.log(C.softmax(network['model'].outputs[1]))
    loss = C.combine([row_loss, col_loss])
    # loss_vector[w][k] = summed negative log-probability of component k
    # at every position where the target word is w
    row_loss_vector = np.zeros((opt.vocabsize, vocab_sqrt))
    col_loss_vector = np.zeros((opt.vocabsize, vocab_sqrt))
    flag = True
    while flag:
        mb = source.next_minibatch(opt.seqlength * opt.batchsize * Communicator.num_workers(),
                                   Communicator.num_workers(), communicator.rank())
        result = loss.eval({
            network['row']: mb[source.input1],
            network['col']: mb[source.input2],
        })
        row_prob = result[loss.outputs[0]]
        col_prob = result[loss.outputs[1]]
        label1 = mb[source.word1].asarray()
        label2 = mb[source.word2].asarray()
        sequences = len(label1)
        for i in range(sequences):
            seqlength = len(row_prob[i])
            for j in range(seqlength):
                row_word = int(label1[i][j][0])
                col_word = int(label2[i][j][0])
                # log-probabilities are negated so that lower means better fit
                row_loss_vector[row_word] -= row_prob[i][j]
                col_loss_vector[col_word] -= col_prob[i][j]
        flag = not mb[source.input1].sweep_end
    # return order matches the (row_loss, col_loss) unpacking in train()
    return row_loss_vector, col_loss_vector
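# Hedged sketch, not the project's allocate_table: conceptually, LightRNN
# re-allocation assigns each word w to a table slot (r, c) so that the total
# accumulated loss row_loss[w][r] + col_loss[w][c] is minimised, which is a
# min-cost assignment problem. The reference below uses SciPy's exact solver
# (an extra dependency) purely for illustration; at vocabulary scale the real
# allocation needs an approximate solver, since the exact one is cubic.
def naive_allocate_sketch(row_loss, col_loss):
    from scipy.optimize import linear_sum_assignment
    vocab, sqrt = row_loss.shape
    # cost[w, slot] with slot = r * sqrt + c; shape (vocab, sqrt * sqrt)
    cost = (row_loss[:, :, None] + col_loss[:, None, :]).reshape(vocab, sqrt * sqrt)
    words, slots = linear_sum_assignment(cost)  # rectangular matrices are allowed
    return {int(w): (int(s) // sqrt, int(s) % sqrt) for w, s in zip(words, slots)}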
def evaluate(network, path, location_path):
    """Return the average per-token cross-entropy over one full data sweep."""
    criterion = create_criterion(network)
    ce = criterion[0]
    source = DataSource(path, opt.vocab_file, location_path, opt.seqlength, opt.batchsize)
    error, tokens = 0, 0
    flag = True
    while flag:
        mb = source.next_minibatch(opt.seqlength * opt.batchsize)
        loss = ce.eval({
            network['row']: mb[source.input1],
            network['col']: mb[source.input2],
            network['row_label']: mb[source.label1],
            network['col_label']: mb[source.label2]
        })
        # each sequence carries one loss value per token; sum them all
        error += sum(np.sum(sequence) for sequence in loss)
        tokens += mb[source.input1].num_samples
        flag = not mb[source.input1].sweep_end
    return error / tokens
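# Illustrative helper, not part of the original script: evaluate() returns the
# average per-token cross-entropy, so the conventional perplexity figure is a
# single exp() away.
def perplexity(network, path, location_path):
    import math
    return math.exp(evaluate(network, path, location_path))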
def train(network, location_path, id):
    """Train for opt.epochs[id] epochs, evaluate after each one, and finish
    the round by re-allocating the word table from the accumulated losses."""
    train_path = os.path.join(opt.datadir, opt.train_file)
    valid_path = os.path.join(opt.datadir, opt.valid_file)
    test_path = os.path.join(opt.datadir, opt.test_file)

    criterion = create_criterion(network)
    ce, pe = criterion[0], criterion[1]

    learner = create_learner(network['model'])
    learner = data_parallel_distributed_learner(learner)
    communicator = learner.communicator()
    trainer = C.Trainer(network['model'], (ce, pe), learner)

    # loop over epochs
    for epoch in range(opt.epochs[id]):
        source = DataSource(train_path, opt.vocab_file, location_path, opt.seqlength, opt.batchsize)
        loss, metric, tokens, batch_id = 0, 0, 0, 0
        start_time = datetime.datetime.now()
        flag = True
        # loop over minibatches in the epoch
        while flag:
            mb = source.next_minibatch(opt.seqlength * opt.batchsize * Communicator.num_workers(),
                                       Communicator.num_workers(), communicator.rank())
            trainer.train_minibatch({
                network['row']: mb[source.input1],
                network['col']: mb[source.input2],
                network['row_label']: mb[source.label1],
                network['col_label']: mb[source.label2]
            })
            samples = trainer.previous_minibatch_sample_count
            loss += trainer.previous_minibatch_loss_average * samples
            metric += trainer.previous_minibatch_evaluation_average * samples
            tokens += samples
            batch_id += 1

            if Communicator.num_workers() > 1:
                communicator.barrier()

            if batch_id != 0 and batch_id % opt.freq == 0:
                diff_time = datetime.datetime.now() - start_time
                print("Epoch {:2}: Minibatch [{:5} - {:5}], loss = {:.6f}, error = {:.6f}, speed = {:3} tokens/s".format(
                    epoch + 1, batch_id - opt.freq + 1, batch_id,
                    loss / tokens, metric / tokens,
                    tokens // max(1, int(diff_time.total_seconds()))))
            flag = not mb[source.input1].sweep_end

        # evaluation action
        if communicator.is_main():
            valid_error = evaluate(network, valid_path, location_path)
            test_error = evaluate(network, test_path, location_path)
            print("Epoch {:2} Done : Valid error = {:.6f}, Test error = {:.6f}".format(
                epoch + 1, valid_error, test_error))
            network['model'].save(os.path.join(opt.outputdir, 'round{}_epoch{}_'.format(id, epoch) + opt.save))
        if Communicator.num_workers() > 1:
            communicator.barrier()

    # word-allocation action: gather per-word losses from all workers, then
    # let the main worker rebuild the word table for the next round
    row_loss, col_loss = calculate_loss_vector(network, train_path, location_path, communicator)
    if Communicator.num_workers() > 1:
        try:
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
            if communicator.is_main():
                for i in range(1, Communicator.num_workers()):
                    row_loss_i, col_loss_i = comm.recv(source=i)
                    row_loss += row_loss_i
                    col_loss += col_loss_i
            else:
                comm.send([row_loss, col_loss], dest=0)
        except ImportError:
            raise RuntimeError("Please install mpi4py to use multiple GPUs!")
        communicator.barrier()
    if communicator.is_main():
        allocate_table(row_loss, col_loss, opt.vocabsize, vocab_sqrt, opt.vocab_file,
                       get_k_round_location_path(id + 1))
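# Hedged sketch of how the rounds could be driven, assuming a create_model()
# factory and an opt.epochs list with one entry per round (both hypothetical
# here; the actual entry point lives outside this excerpt). Each round trains
# on the current word-table layout and train() emits the layout for the next.
def main_sketch():
    for round_id in range(len(opt.epochs)):
        network = create_model(get_k_round_location_path(round_id))  # hypothetical factory
        train(network, get_k_round_location_path(round_id), round_id)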