Example #1
def main():
    '''read arguments'''
    parser = build_parser()
    args = parser.parse_args()
    config = args

    print("Loading Data!")
    train_corpus, val_corpus_bins = load_data(config, num_bins=config.bins)
    data_dir = os.path.join('data', config.dataset)

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

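    # Serialize the corpora with pickle; the training pipeline presumably
    # reloads these .pk files instead of re-running preprocessing.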
    print("Writing Train corpus")
    with open(os.path.join(data_dir, 'train_corpus.pk'), 'wb') as f:
        pickle.dump(file=f, obj=train_corpus)
    print("Done")

    print("Writing Val corpus bins")
    with open(os.path.join(data_dir, 'val_corpus_bins.pk'), 'wb') as f:
        pickle.dump(file=f, obj=val_corpus_bins)
    print("Done")

    print("Writing Train text files")
    with open(os.path.join(data_dir, 'train_src.txt'), 'w') as f:
        f.write('\n'.join(train_corpus.source))

    with open(os.path.join(data_dir, 'train_tgt.txt'), 'w') as f:
        f.write('\n'.join(train_corpus.target))
    print("Done")

    print("Writing Val text files")
    for i, val_corpus_bin in enumerate(val_corpus_bins):
        with open(os.path.join(data_dir, 'val_src_bin{}.txt'.format(i)),
                  'w') as f:
            f.write('\n'.join(val_corpus_bin.source))

        with open(os.path.join(data_dir, 'val_tgt_bin{}.txt'.format(i)),
                  'w') as f:
            f.write('\n'.join(val_corpus_bin.target))
    print("Done")

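    # Summary statistics: depth_counter is assumed to return a per-token depth
    # matrix, so .sum(1).max() gives the maximum nesting depth of a line.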
    print("Gathering Length and Depth info of the dataset")
    train_depths = list(
        set([
            train_corpus.Lang.depth_counter(line).sum(1).max()
            for line in train_corpus.source
        ]))
    train_lens = list(set([len(line) for line in train_corpus.source]))

    val_lens_bins, val_depths_bins = [], []
    for i, val_corpus in enumerate(val_corpus_bins):
        val_depths = list(
            set([
                val_corpus.Lang.depth_counter(line).sum(1).max()
                for line in val_corpus.source
            ]))
        val_depths_bins.append(val_depths)

        val_lens = list(set([len(line) for line in val_corpus.source]))
        val_lens_bins.append(val_lens)

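    # Record the length/depth ranges and sizes of the train set and each
    # validation bin in data_info.json for quick inspection.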
    info_dict = {}
    info_dict['Lang'] = '{}-{}'.format(config.lang, config.num_par)
    info_dict['Train Lengths'] = (min(train_lens), max(train_lens))
    info_dict['Train Depths'] = (int(min(train_depths)),
                                 int(max(train_depths)))
    info_dict['Train Size'] = len(train_corpus.source)

    for i, (val_lens,
            val_depths) in enumerate(zip(val_lens_bins, val_depths_bins)):
        info_dict['Val Bin-{} Lengths'.format(i)] = (min(val_lens),
                                                     max(val_lens))
        info_dict['Val Bin-{} Depths'.format(i)] = (int(min(val_depths)),
                                                    int(max(val_depths)))
        info_dict['Val Bin-{} Size'.format(i)] = len(val_corpus_bins[i].source)

    with open(os.path.join('data', config.dataset, 'data_info.json'),
              'w') as f:
        json.dump(obj=info_dict, fp=f)

    print("Done")
Example #2
def main():

    # Parse arguments
    parser = build_parser()
    args = parser.parse_args()
    args.mode = args.mode.lower()

    # div_gps = args.div_gps
    # div_lam = args.div_lam
    if args.mode == 'train':
        if len(args.run_name.split()) == 0:
            # Default the run name to a timestamp when none is supplied
            args.run_name = datetime.fromtimestamp(time.time()).strftime(
                args.date_fmt)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

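    # Decode-time options (selection method, lambda and the a/b parameters)
    # are captured here so they can be re-applied after the saved config is
    # reloaded further below.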
    smethod = str(args.selec)
    data_sub = str(os.path.join('data', args.dataset, 'test',
                                'src.txt')).split('/')[-1].split('.')[0]
    slam = args.slam

    a1 = args.a1
    a2 = args.a2
    b1 = args.b1
    b2 = args.b2
    sparam = [a1, a2, b1, b2]

    outdir = str(args.out_dir)

    # GPU initialization
    device = gpu_init_pytorch(args.gpu)

    log_folder_name = os.path.join('Logs', args.run_name)
    create_save_directories('Logs', 'Model', args.run_name)
    logger = get_logger(__name__, args.run_name, args.log_fmt, logging.INFO,
                        os.path.join(log_folder_name, 's2s.log'))

    if args.mode == 'train':
        train_dataloader, val_dataloader = read_files(args, logger)
        logger.info('Creating vocab ...')

        voc = Voc(args.dataset)
        voc = create_vocab_dict(args, voc, train_dataloader)

        logger.info('Vocab created with number of words = {}'.format(
            voc.nwords))
        logger.info('Saving Vocabulary file')

        with open(os.path.join('Model', args.run_name, 'vocab.p'), 'wb') as f:
            pickle.dump(voc, f, protocol=pickle.HIGHEST_PROTOCOL)

        logger.info('Vocabulary file saved in {}'.format(
            os.path.join('Model', args.run_name, 'vocab.p')))
    else:
        test_dataloader = read_files(args, logger)
        logger.info('Loading Vocabulary file')

        with open(os.path.join('Model', args.run_name, 'vocab.p'), 'rb') as f:
            voc = pickle.load(f)

        logger.info('Vocabulary file Loaded from {}'.format(
            os.path.join('Model', args.run_name, 'vocab.p')))

    # Get Checkpoint, return None if no checkpoint present
    checkpoint = get_latest_checkpoint('Model', args.run_name, logger)

    if args.mode == 'train':
        if checkpoint is None:
            logger.info('Starting a fresh training procedure')
            ep_offset = 0
            min_val_loss = 1e8
            max_val_bleu = 0.0
            config_file_name = os.path.join('Model', args.run_name, 'config.p')

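            # Pretrained word2vec embeddings are 300-dimensional, so emb_size
            # is overridden to match when they are used (assumed rationale).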
            if args.use_word2vec:
                args.emb_size = 300

            model = s2s(args, voc, device, logger)

            with open(config_file_name, 'wb') as f:
                pickle.dump(vars(args), f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            config_file_name = os.path.join('Model', args.run_name, 'config.p')

            with open(config_file_name, 'rb') as f:
                args = AttrDict(pickle.load(f))

            if args.use_word2vec:
                args.emb_size = 300

            model = s2s(args, voc, device, logger)

            ep_offset, train_loss, min_val_loss, max_val_bleu, voc = load_checkpoint(
                model, args.mode, checkpoint, logger, device)

            logger.info('Resuming Training From ')
            od = OrderedDict()
            od['Epoch'] = ep_offset
            od['Train_loss'] = train_loss
            od['Validation_loss'] = min_val_loss
            od['Validation_Bleu'] = max_val_bleu
            print_log(logger, od)
            ep_offset += 1

        # Call Training function
        train(model, train_dataloader, val_dataloader, voc, device, args,
              logger, ep_offset, min_val_loss, max_val_bleu)
    else:
        if checkpoint is None:
            logger.info('Cannot decode because of absence of checkpoints')
            sys.exit()
        else:
            config_file_name = os.path.join('Model', args.run_name, 'config.p')
            beam_width = args.beam_width
            gpu = args.gpu

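            # Reload the config saved at training time, then restore the
            # decode-time overrides captured above.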
            with open(config_file_name, 'rb') as f:
                args = AttrDict(pickle.load(f))
                args.beam_width = beam_width
                args.gpu = gpu
                # args.div_beam = div_beam
                # args.div_gps = div_gps
                # args.div_lam = div_lam

            if args.use_word2vec:
                args.emb_size = 300

            args.slam = slam
            args.sparam = sparam
            args.out_dir = outdir

            model = s2s(args, voc, device, logger)

            ep_offset, train_loss, min_val_loss, max_val_bleu, voc = load_checkpoint(
                model, args.mode, checkpoint, logger, device)

            logger.info('Decoding from')
            od = OrderedDict()
            od['Epoch'] = ep_offset
            od['Train_Loss'] = train_loss
            od['Validation_Loss'] = min_val_loss
            od['Validation_Bleu'] = max_val_bleu
            print_log(logger, od)

        if args.beam_width == 1:
            decode_greedy(model, test_dataloader, voc, device, args, logger)
        else:
            decode_beam(model, test_dataloader, voc, device, args, logger,
                        smethod, data_sub)
Example #3
def main():
	'''read arguments'''
	parser = build_parser()
	args = parser.parse_args()
	config = args
	mode = config.mode
	if mode == 'train':
		is_train = True
	else:
		is_train = False

	''' Set seed for reproducibility'''
	np.random.seed(config.seed)
	torch.manual_seed(config.seed)
	random.seed(config.seed)

	'''GPU initialization'''
	device = gpu_init_pytorch(config.gpu)
	#device = 'cpu'
	'''Run Config files/paths'''
	run_name = config.run_name
	config.log_path = os.path.join(log_folder, run_name)
	config.model_path = os.path.join(model_folder, run_name)
	config.board_path = os.path.join(board_path, run_name)

	vocab_path = os.path.join(config.model_path, 'vocab.p')
	config_file = os.path.join(config.model_path, 'config.p')
	log_file = os.path.join(config.log_path, 'log.txt')

	if config.results:
		config.result_path = os.path.join(result_folder, 'val_results_{}.json'.format(config.dataset))

	if is_train:
		create_save_directories(config.log_path, config.model_path)
	else:
		create_save_directories(config.log_path, config.result_path)

	logger = get_logger(run_name, log_file, logging.DEBUG)
	writer = SummaryWriter(config.board_path)

	logger.debug('Created Relevant Directories')
	logger.info('Experiment Name: {}'.format(config.run_name))

	'''Read Files and create/load Vocab'''
	if is_train:

		logger.debug('Creating Vocab and loading Data ...')
		train_loader, val_loader_bins, voc  = load_data(config, logger)

		logger.info(
			'Vocab Created with number of words : {}'.format(voc.nwords))		

		with open(vocab_path, 'wb') as f:
			pickle.dump(voc, f, protocol=pickle.HIGHEST_PROTOCOL)
		logger.info('Vocab saved at {}'.format(vocab_path))



	else:
		logger.info('Loading Vocab File...')

		with open(vocab_path, 'rb') as f:
			voc = pickle.load(f)

		logger.info('Vocab Files loaded from {}'.format(vocab_path))

		logger.info("Loading Test Dataloaders...")
		config.batch_size = 1
		test_loader_bins = load_data(config, logger, voc)
		logger.info("Done loading test dataloaders")

	# print('Done')

	# TO DO : Load Existing Checkpoints here


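	# Training mode: build (or restore) the model, save the config and train.
	# Test mode: reload the saved config and evaluate each test bin.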
	if is_train:
		
		max_val_acc = 0.0
		epoch_offset= 0


		model = build_model(config=config, voc=voc, device=device, logger=logger)
		if config.load_model:
			# Resume from the latest checkpoint if one exists; otherwise the
			# freshly built model is trained from scratch.
			checkpoint = get_latest_checkpoint(config.model_path, logger)
			if checkpoint:
				ckpt = torch.load(checkpoint, map_location=lambda storage, loc: storage)
				#config.lr = checkpoint['lr']
				model.load_state_dict(ckpt['model_state_dict'])
				model.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
		# pdb.set_trace()

		logger.info('Initialized Model')

		with open(config_file, 'wb') as f:
			pickle.dump(vars(config), f, protocol=pickle.HIGHEST_PROTOCOL)

		logger.debug('Config File Saved')

		logger.info('Starting Training Procedure')
		train_model(model, train_loader, val_loader_bins, voc,
					device, config, logger, epoch_offset, max_val_acc, writer)

	else:

		gpu = config.gpu

		with open(config_file, 'rb') as f:
			bias = config.bias
			extraffn = config.extraffn
			config = AttrDict(pickle.load(f))
			config.gpu = gpu
			config.bins = len(test_loader_bins)
			config.batch_size = 1
			config.bias = bias
			config.extraffn = extraffn
			# To do: remove it later
			#config.num_labels =2  

		model = build_model(config=config, voc=voc, device=device, logger=logger)
		checkpoint = get_latest_checkpoint(config.model_path, logger)
		ep_offset, train_loss, score, voc = load_checkpoint(
			model, config.mode, checkpoint, logger, device, bins = config.bins)

		logger.info('Prediction from')
		od = OrderedDict()
		od['epoch'] = ep_offset
		od['train_loss'] = train_loss
		if config.bins != -1:
			for i in range(config.bins):
				od['max_val_acc_bin{}'.format(i)] = score[i]
		else:
			od['max_val_acc'] = score
		print_log(logger, od)
		# pdb.set_trace()
		#test_acc_epoch, test_loss_epoch = run_validation(config, model, test_loader, voc, device, logger)
		#test_analysis_dfs = []
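		# Evaluate each length/depth bin separately; run_test is assumed to
		# return (accuracy, analysis DataFrame), and each bin gets its own CSV.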
		for i in range(config.bins):
			test_acc_epoch, test_analysis_df = run_test(config, model, test_loader_bins[i], voc, device, logger)
			logger.info('Bin {} Accuracy: {}'.format(i, test_acc_epoch))
			#test_analysis_dfs.append(test_analysis_df)
			test_analysis_df.to_csv(os.path.join(result_folder, '{}_{}_test_analysis_bin{}.csv'.format(config.dataset, config.model_type, i)))
		logger.info("Analysis results written to {}...".format(result_folder))
Example #4
def main():
    '''Parse Arguments'''
    parser = build_parser()
    args = parser.parse_args()
    '''Specify Seeds for reproducibility'''
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    '''Configs'''
    device = gpu_init_pytorch(args.gpu)

    mode = args.mode
    if mode == 'train':
        is_train = True
    else:
        is_train = False

    # ckpt= args.ckpt

    run_name = args.run_name
    args.log_path = os.path.join(log_folder, run_name)
    args.model_path = os.path.join(model_folder, run_name)
    args.board_path = os.path.join(board_path, run_name)
    args.outputs_path = os.path.join(outputs_folder, run_name)

    args_file = os.path.join(args.model_path, 'args.p')

    log_file = os.path.join(args.log_path, 'log.txt')

    if args.results:
        args.result_path = os.path.join(
            result_folder, 'val_results_{}.json'.format(args.dataset))

    logging_var = bool(args.logging)

    if is_train:
        create_save_directories(args.log_path)
        create_save_directories(args.model_path)
        create_save_directories(args.outputs_path)
    else:
        create_save_directories(args.log_path)
        create_save_directories(args.result_path)

    logger = get_logger(run_name, log_file, logging.DEBUG)

    logger.debug('Created Relevant Directories')
    logger.info('Experiment Name: {}'.format(args.run_name))

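    # args.mt selects the machine-translation path with separate source/target
    # vocabularies; the else branch below handles the synthetic-data setup
    # with a single shared vocabulary.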
    if args.mt:

        vocab1_path = os.path.join(args.model_path, 'vocab1.p')
        vocab2_path = os.path.join(args.model_path, 'vocab2.p')

        if is_train:
            #pdb.set_trace()
            train_dataloader, val_dataloader = load_data(args, logger)

            logger.debug('Creating Vocab...')

            voc1 = Voc()
            voc1.create_vocab_dict(args, 'src', train_dataloader)

            # To Do : Remove Later
            voc1.add_to_vocab_dict(args, 'src', val_dataloader)

            voc2 = Voc()
            voc2.create_vocab_dict(args, 'trg', train_dataloader)

            # To Do : Remove Later
            voc2.add_to_vocab_dict(args, 'trg', val_dataloader)
            logger.info('Vocab Created with number of words : {}'.format(
                voc1.nwords))

            with open(vocab1_path, 'wb') as f:
                pickle.dump(voc1, f, protocol=pickle.HIGHEST_PROTOCOL)
            with open(vocab2_path, 'wb') as f:
                pickle.dump(voc2, f, protocol=pickle.HIGHEST_PROTOCOL)
            logger.info('Vocab saved at {}'.format(vocab1_path))

        else:
            test_dataloader = load_data(args, logger)
            logger.info('Loading Vocab File...')

            with open(vocab1_path, 'rb') as f:
                voc1 = pickle.load(f)
            with open(vocab2_path, 'rb') as f:
                voc2 = pickle.load(f)
            logger.info(
                'Vocab Files loaded from {}\nNumber of Words: {}'.format(
                    vocab1_path, voc1.nwords))

            # print('Done')

            # TO DO : Load Existing Checkpoints here
        checkpoint = get_latest_checkpoint(args.model_path, logger)
        '''Param Specs'''
        layers = args.layers
        heads = args.heads
        d_model = args.d_model
        d_ff = args.d_ff
        max_len = args.max_length
        dropout = args.dropout
        BATCH_SIZE = args.batch_size
        epochs = args.epochs

        if logging_var:
            meta_fname = os.path.join(args.log_path, 'meta.txt')
            loss_fname = os.path.join(args.log_path, 'loss.txt')

            meta_fh = open(meta_fname, 'w')
            loss_fh = open(loss_fname, 'w')

            print('Log Files created at: {}'.format(args.log_path))

            write_meta(args, meta_fh)
        """stime= time.time()
					print('Loading Data...')
					train, val, test, SRC, TGT = build_data()
					etime= (time.time()-stime)/60
					print('Data Loaded\nTime Taken:{}'.format(etime ))"""

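        # The PAD id is used both for masking the source and as the ignored
        # index of the label-smoothing criterion.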
        pad_idx = voc1.w2id['PAD']

        model = make_model(voc1.nwords,
                           voc2.nwords,
                           N=layers,
                           h=heads,
                           d_model=d_model,
                           d_ff=d_ff,
                           dropout=dropout)
        model.to(device)

        criterion = LabelSmoothing(size=voc2.nwords,
                                   padding_idx=pad_idx,
                                   smoothing=0.1)
        criterion.to(device)

        # train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=device,
        # 						repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
        # 						batch_size_fn=batch_size_fn, train=True)

        # valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=device,
        # 						repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
        # 						batch_size_fn=batch_size_fn, train=False)

        if mode == 'train':
            model_opt = NoamOpt(
                model.src_embed[0].d_model, 1, 2000,
                torch.optim.Adam(model.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))
            max_val_score = 0.0
            min_error_score = 100.0
            epoch_offset = 0
            for epoch in range(epochs):
                # pdb.set_trace()
                #if epoch%3==0:

                print('Training Epoch: ', epoch)
                model.train()
                run_epoch((rebatch(args, device, voc1, voc2, pad_idx, b)
                           for b in train_dataloader), model,
                          LossCompute(model.generator,
                                      criterion,
                                      device=device,
                                      opt=model_opt))
                model.eval()
                # loss = run_epoch((rebatch(args, device, voc1, voc2, pad_idx, b) for b in val_dataloader),
                #  				  model,
                #  				  LossCompute(model.generator, criterion, device=device, opt=None))
                # loss_str= "Epoch: {} \t Val Loss: {}\n".format(epoch,loss)
                # print(loss_str)

                refs = []
                hyps = []
                error_score = 0

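                # Greedy-decode the validation set, collecting references and
                # hypotheses for BLEU plus a task-specific error score
                # (cal_score is assumed to compare decoded tokens with targets).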
                for i, batch in enumerate(val_dataloader):
                    sent1s = sents_to_idx(voc1, batch['src'], args.max_length)
                    sent2s = sents_to_idx(voc2, batch['trg'], args.max_length)
                    sent1_var, sent2_var, input_len1, input_len2 = process_batch(
                        sent1s, sent2s, voc1, voc2, device, voc1.id2w[pad_idx])

                    sent1s = idx_to_sents(voc1, sent1_var, no_eos=True)
                    sent2s = idx_to_sents(voc2, sent2_var, no_eos=True)

                    #pdb.set_trace()
                    # for l in range(len(batch['src'])):
                    # 	if len(batch['src'][l].split())!=9:
                    # 		print(l)

                    #for eg in range(sent1_var.size(0)):
                    src = sent1_var.transpose(0, 1)
                    src_mask = (src != voc1.w2id['PAD']).unsqueeze(-2)

                    #refs.append([' '.join(sent2s[eg])])
                    refs += [[' '.join(sent2s[i])]
                             for i in range(sent2_var.size(1))]

                    # pdb.set_trace()
                    out = greedy_decode(model,
                                        src,
                                        src_mask,
                                        max_len=60,
                                        start_symbol=voc2.w2id['<s>'],
                                        pad=pad_idx)

                    words = []

                    decoded_words = [[] for i in range(out.size(0))]
                    ends = []

                    #pdb.set_trace()

                    #print("Translation:", end="\t")
                    for z in range(1, out.size(1)):
                        for b in range(len(decoded_words)):
                            sym = voc2.id2w[out[b, z].item()]
                            if b not in ends:
                                if sym == "</s>":
                                    ends.append(b)
                                    continue
                                #print(sym, end =" ")
                                decoded_words[b].append(sym)

                    with open(args.outputs_path + '/outputs.txt',
                              'a') as f_out:
                        f_out.write('Batch: ' + str(i) + '\n')
                        f_out.write(
                            '---------------------------------------\n')
                        for z in range(len(decoded_words)):
                            try:
                                f_out.write('Example: ' + str(z) + '\n')
                                f_out.write('Source: ' + batch['src'][z] +
                                            '\n')
                                f_out.write('Target: ' + batch['trg'][z] +
                                            '\n')
                                f_out.write('Generated: ' +
                                            stack_to_string(decoded_words[z]) +
                                            '\n' + '\n')
                            except Exception:
                                logger.warning('Exception: Failed to generate')
                                pdb.set_trace()
                                break
                        f_out.write(
                            '---------------------------------------\n')

                    hyps += [
                        ' '.join(decoded_words[z])
                        for z in range(len(decoded_words))
                    ]
                    #hyps.append(stack_to_string(words))

                    error_score += cal_score(decoded_words, batch['trg'])

                    #print()
                    #print("Target:", end="\t")
                    for z in range(1, sent2_var.size(0)):
                        sym = voc2.id2w[sent2_var[z, 0].item()]
                        if sym == "</s>": break
                        #print(sym, end =" ")
                    #print()
                    #break

                val_bleu_epoch = bleu_scorer(refs, hyps)
                print('Epoch: {}  Val bleu: {}'.format(epoch,
                                                       val_bleu_epoch[0]))
                print('Epoch: {}  Val Error: {}'.format(
                    epoch, error_score / len(val_dataloader)))

                # if logging_var:
                # 	loss_fh.write(loss_str)
                if epoch % 10 == 0:
                    ckpt_path = os.path.join(args.model_path, 'model.pt')
                    logger.info('Saving Checkpoint at : {}'.format(ckpt_path))
                    torch.save(model.state_dict(), ckpt_path)
                    print('Model saved at: {}'.format(ckpt_path))

        else:
            model.load_state_dict(
                torch.load(os.path.join(args.model_path, 'model.pt')))
            model.eval()

        # pdb.set_trace()
        # for i, batch in enumerate(val_dataloader):
        # 	sent1s = sents_to_idx(voc1, batch['src'], args.max_length)
        # 	sent2s = sents_to_idx(voc2, batch['trg'], args.max_length)
        # 	sent1_var, sent2_var, input_len1, input_len2  = process_batch(sent1s, sent2s, voc1, voc2, device)
        # 	src = sent1_var.transpose(0, 1)[:1]
        # 	src_mask = (src != voc1.w2id['PAD']).unsqueeze(-2)
        # 	out = greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=voc2.w2id['<s>'])
        # 	print("Translation:", end="\t")
        # 	for i in range(1, out.size(1)):
        # 		sym = voc2.id2w[out[0, i].item()]
        # 		if sym == "</s>": break
        # 		print(sym, end =" ")
        # 	print()
        # 	print("Target:", end="\t")
        # 	for i in range(1, sent2_var.size(0)):
        # 		sym = voc2.id2w[sent2_var[i, 0].item()]
        # 		if sym == "</s>": break
        # 		print(sym, end =" ")
        # 	print()
        # 	break

    else:
        '''Code for Synthetic Data'''
        vocab_path = os.path.join(args.model_path, 'vocab.p')

        if is_train:
            #pdb.set_trace()
            train_dataloader, val_dataloader = load_data(args, logger)

            logger.debug('Creating Vocab...')

            voc = Syn_Voc()
            voc.create_vocab_dict(args, train_dataloader)

            # To Do : Remove Later
            voc.add_to_vocab_dict(args, val_dataloader)

            logger.info('Vocab Created with number of words : {}'.format(
                voc.nwords))

            with open(vocab_path, 'wb') as f:
                pickle.dump(voc, f, protocol=pickle.HIGHEST_PROTOCOL)

            logger.info('Vocab saved at {}'.format(vocab_path))

        else:
            test_dataloader = load_data(args, logger)
            logger.info('Loading Vocab File...')

            with open(vocab_path, 'rb') as f:
                voc = pickle.load(f)

            logger.info(
                'Vocab Files loaded from {}\nNumber of Words: {}'.format(
                    vocab_path, voc.nwords))

            # print('Done')

            # TO DO : Load Existing Checkpoints here
        # checkpoint = get_latest_checkpoint(args.model_path, logger)
        '''Param Specs'''
        layers = args.layers
        heads = args.heads
        d_model = args.d_model
        d_ff = args.d_ff
        max_len = args.max_length
        dropout = args.dropout
        BATCH_SIZE = args.batch_size
        epochs = args.epochs

        if logging_var:
            meta_fname = os.path.join(args.log_path, 'meta.txt')
            loss_fname = os.path.join(args.log_path, 'loss.txt')

            meta_fh = open(meta_fname, 'w')
            loss_fh = open(loss_fname, 'w')

            print('Log Files created at: {}'.format(args.log_path))

            write_meta(args, meta_fh)
        """stime= time.time()
					print('Loading Data...')
					train, val, test, SRC, TGT = build_data()
					etime= (time.time()-stime)/60
					print('Data Loaded\nTime Taken:{}'.format(etime ))"""

        pad_idx = voc.w2id['PAD']

        model = make_model(voc.nwords,
                           voc.nwords,
                           N=layers,
                           h=heads,
                           d_model=d_model,
                           d_ff=d_ff,
                           dropout=dropout)
        model.to(device)

        logger.info('Initialized Model')

        criterion = LabelSmoothing(size=voc.nwords,
                                   padding_idx=pad_idx,
                                   smoothing=0.1)
        criterion.to(device)

        # train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=device,
        # 						repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
        # 						batch_size_fn=batch_size_fn, train=True)

        # valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=device,
        # 						repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
        # 						batch_size_fn=batch_size_fn, train=False)

        if mode == 'train':
            model_opt = NoamOpt(
                model.src_embed[0].d_model, 1, 3000,
                torch.optim.Adam(model.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))
            max_bleu_score = 0.0
            min_error_score = 100.0
            epoch_offset = 0
            logger.info('Starting Training Procedure')
            for epoch in range(epochs):
                # pdb.set_trace()
                #if epoch%3==0:

                print('Training Epoch: ', epoch)
                model.train()
                start_time = time.time()
                run_epoch((rebatch(args, device, voc, voc, pad_idx, b)
                           for b in train_dataloader), model,
                          LossCompute(model.generator,
                                      criterion,
                                      device=device,
                                      opt=model_opt))

                time_taken = (time.time() - start_time) / 60.0
                logger.debug(
                    'Training for epoch {} completed...\nTime Taken: {}'.
                    format(epoch, time_taken))
                logger.debug('Starting Validation')

                model.eval()
                # loss = run_epoch((rebatch(args, device, voc1, voc2, pad_idx, b) for b in val_dataloader),
                #  				  model,
                #  				  LossCompute(model.generator, criterion, device=device, opt=None))
                # loss_str= "Epoch: {} \t Val Loss: {}\n".format(epoch,loss)
                # print(loss_str)

                refs = []
                hyps = []
                error_score = 0

                for i, batch in enumerate(val_dataloader):
                    sent1s = sents_to_idx(voc, batch['src'], args.max_length)
                    sent2s = sents_to_idx(voc, batch['trg'], args.max_length)
                    sent1_var, sent2_var, input_len1, input_len2 = process_batch(
                        sent1s, sent2s, voc, voc, device, voc.id2w[pad_idx])

                    sent1s = idx_to_sents(voc, sent1_var, no_eos=True)
                    sent2s = idx_to_sents(voc, sent2_var, no_eos=True)

                    # pdb.set_trace()
                    # for l in range(len(batch['src'])):
                    # 	if len(batch['src'][l].split())!=9:
                    # 		print(l)

                    #for eg in range(sent1_var.size(0)):
                    src = sent1_var.transpose(0, 1)

                    ### FOR NON-DIRECTIONAL ###
                    # src_mask = (src != voc.w2id['PAD']).unsqueeze(-2)

                    ### FOR DIRECTIONAL ###
                    src_mask = make_std_mask(src, pad_idx)
                    src_mask_bi = make_bi_std_mask(src, pad_idx)
                    src_mask_dec = (src != voc.w2id['PAD']).unsqueeze(-2)
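                    # Causal (src_mask), bidirectional (src_mask_bi) and
                    # padding-only (src_mask_dec) views of the source; the
                    # exact semantics of the make_*_mask helpers are assumed.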
                    #refs.append([' '.join(sent2s[eg])])
                    # refs += [[' '.join(sent2s[i])] for i in range(sent2_var.size(1))]
                    refs += [[x] for x in batch['trg']]

                    out = greedy_decode(model,
                                        src,
                                        src_mask,
                                        max_len=max_len,
                                        start_symbol=voc.w2id['<s>'],
                                        pad=pad_idx,
                                        src_mask_dec=src_mask_dec,
                                        src_mask_bi=src_mask_bi)

                    words = []

                    decoded_words = [[] for i in range(out.size(0))]
                    ends = []

                    # pdb.set_trace()

                    #print("Translation:", end="\t")
                    for z in range(1, out.size(1)):
                        for b in range(len(decoded_words)):
                            sym = voc.id2w[out[b, z].item()]
                            if b not in ends:
                                if sym == "</s>":
                                    ends.append(b)
                                    continue
                                #print(sym, end =" ")
                                decoded_words[b].append(sym)

                    with open(args.outputs_path + '/outputs.txt',
                              'a') as f_out:
                        f_out.write('Batch: ' + str(i) + '\n')
                        f_out.write(
                            '---------------------------------------\n')
                        for z in range(len(decoded_words)):
                            try:
                                f_out.write('Example: ' + str(z) + '\n')
                                f_out.write('Source: ' + batch['src'][z] +
                                            '\n')
                                f_out.write('Target: ' + batch['trg'][z] +
                                            '\n')
                                f_out.write('Generated: ' +
                                            stack_to_string(decoded_words[z]) +
                                            '\n' + '\n')
                            except Exception:
                                logger.warning('Exception: Failed to generate')
                                pdb.set_trace()
                                break
                        f_out.write(
                            '---------------------------------------\n')

                    hyps += [
                        ' '.join(decoded_words[z])
                        for z in range(len(decoded_words))
                    ]
                    #hyps.append(stack_to_string(words))

                    if args.ap:
                        error_score += cal_score_AP(decoded_words,
                                                    batch['trg'])
                    else:
                        error_score += cal_score(decoded_words, batch['trg'])

                    #print()
                    #print("Target:", end="\t")
                    for z in range(1, sent2_var.size(0)):
                        sym = voc.id2w[sent2_var[z, 0].item()]
                        if sym == "</s>": break
                        #print(sym, end =" ")
                    #print()
                    #break

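                # Track the best error and BLEU across epochs; these are the
                # values store_results persists after training.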
                if (error_score / len(val_dataloader)) < min_error_score:
                    min_error_score = error_score / len(val_dataloader)

                val_bleu_epoch = bleu_scorer(refs, hyps)

                if max_bleu_score < val_bleu_epoch[0]:
                    max_bleu_score = val_bleu_epoch[0]

                logger.info('Epoch: {}  Val bleu: {}'.format(
                    epoch, val_bleu_epoch[0]))
                logger.info('Maximum Bleu: {}'.format(max_bleu_score))
                logger.info('Epoch: {}  Val Error: {}'.format(
                    epoch, error_score / len(val_dataloader)))
                logger.info('Minimum Error: {}'.format(min_error_score))

                # if logging_var:
                # 	loss_fh.write(loss_str)
                if epoch % 5 == 0:
                    ckpt_path = os.path.join(args.model_path, 'model.pt')
                    logger.info('Saving Checkpoint at : {}'.format(ckpt_path))
                    torch.save(model.state_dict(), ckpt_path)
                    print('Model saved at: {}'.format(ckpt_path))

            store_results(args, max_bleu_score, min_error_score)
            logger.info('Scores saved at {}'.format(args.result_path))

        else:
            model.load_state_dict(
                torch.load(os.path.join(args.model_path, 'model.pt')))
            model.eval()
Example #5
def main():
    parser = build_parser()
    args = parser.parse_args()
    args.tree_height = [int(s) for s in args.tree_height.split(",")]
    use_ptr = args.use_ptr
    cov_after_ep = args.cov_after_ep
    height_dec = args.height_dec

    if args.mode == 'train':
        if len(args.run_name.split()) == 0:
            args.run_name = datetime.fromtimestamp(time.time()).strftime(
                args.date_fmt)

    if args.pretrained_encoder is None or args.pretrained_encoder == 'bert_all':
        args.use_attn = True

    # SET SEEDS
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # ASSIGN GPU
    device = gpu_init_pytorch(args.gpu)
    # CREATE LOGGING FOLDER
    log_folder_name = os.path.join('Logs', args.run_name)
    create_save_directories('Logs', 'Models', args.run_name)

    # Comet ML - Log all params
    experiment = None
    if not args.debug:
        experiment = Experiment(api_key=_API_KEY,
                                project_name=args.project_name,
                                workspace="NAN")
        experiment.set_name(args.run_name)
        experiment.log_parameters(vars(args))

    logger = get_logger(__name__, 'temp_run', args.log_fmt, logging.INFO,
                        os.path.join(log_folder_name, 'SYN-Par.log'))
    logger.info('Run name: {}'.format(args.run_name))

    if args.mode == 'train':
        train_dataloader, val_dataloader, test_dataloader = read_files(
            args, logger)
        logger.info('Creating vocab ...')

        voc1 = Voc(args.dataset + 'sents')
        voc2 = Voc(args.dataset + 'trees')
        voc_file = os.path.join('Models', args.run_name, 'vocab.p')

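        # Reuse an existing vocab.p when restarting a run; otherwise build the
        # vocabulary from the training data and save it.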
        if os.path.exists(voc_file):
            logger.info('Loading vocabulary from {}'.format(voc_file))
            with open(voc_file, 'rb') as f:
                voc = pickle.load(f)
        else:
            voc = create_vocab_dict(args, voc1, voc2, train_dataloader)

            logger.info('Vocab created with number of words = {}'.format(
                voc.nwords))
            logger.info('Saving Vocabulary file')

            with open(voc_file, 'wb') as f:
                pickle.dump(voc, f, protocol=pickle.HIGHEST_PROTOCOL)
            logger.info('Vocabulary file saved in {}'.format(
                os.path.join('Models', args.run_name, 'vocab.p')))
    else:
        config_file_name = os.path.join('Models', args.run_name, 'config.p')
        mode = args.mode
        batch_size = args.batch_size
        beam_width = args.beam_width
        gpu = args.gpu
        tree_height2 = 40
        use_glove = args.use_glove
        max_length = args.max_length
        datatype = args.datatype
        res_file = args.res_file
        dataset = args.dataset
        load_from_ep = args.load_from_ep
        run_name = args.run_name
        max_epochs = args.max_epochs

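        # Reload the training-time config, then re-apply the command-line
        # values captured above so decoding uses the current settings.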
        with open(config_file_name, 'rb') as f:
            args = AttrDict(pickle.load(f))

        args.mode = mode
        args.batch_size = batch_size
        args.load_from_ep = load_from_ep
        args.dataset = dataset
        args.beam_width = beam_width
        args.gpu = gpu
        args.height_dec = height_dec
        args.tree_height2 = tree_height2
        args.use_glove = use_glove
        args.max_length = max_length
        args.datatype = datatype
        args.res_file = res_file
        args.run_name = run_name
        args.max_epochs = max_epochs

        test_dataloader = read_files(args, logger)

        logger.info('Loading Vocabulary file')
        with open(os.path.join('Models', args.run_name, 'vocab.p'), 'rb') as f:
            voc = pickle.load(f)
        logger.info('Vocabulary file Loaded from {}'.format(
            os.path.join('Models', args.run_name, 'vocab.p')))

    checkpoint = get_latest_checkpoint('Models', args.run_name, logger,
                                       args.load_from_ep)

    if args.mode == 'train':
        if checkpoint is None:
            logger.info('Starting a fresh training procedure')
            ep_offset = 0
            min_val_loss = 1e8
            max_val_bleu = 0.0
            config_file_name = os.path.join('Models', args.run_name,
                                            'config.p')
            if args.use_word2vec:
                logger.info(
                    'Over-writing emb_size to 300 because argument use_word2vec has been set to True'
                )
                args.emb_size = 300

            model = SYN_Par(args, voc, device, logger, experiment)
            with open(config_file_name, 'wb') as f:
                pickle.dump(vars(args), f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            config_file_name = os.path.join('Models', args.run_name,
                                            'config.p')
            debug = args.debug
            max_epochs = args.max_epochs
            with open(config_file_name, 'rb') as f:
                args = AttrDict(pickle.load(f))

            if args.use_word2vec:
                logger.info(
                    'Over-writing emb_size to 300 because argument use_word2vec has been set to True'
                )
                args.emb_size = 300
            args.use_ptr = use_ptr
            args.debug = debug
            args.cov_after_ep = cov_after_ep
            args.max_epochs = max_epochs

            model = SYN_Par(args, voc, device, logger, experiment)
            ep_offset, train_loss, min_val_loss, max_val_bleu, voc = load_checkpoint(
                model, args.mode, checkpoint, logger, device,
                args.pretrained_encoder)

            logger.info('Resuming Training From ')
            od = OrderedDict()
            if ep_offset is None and train_loss is None and min_val_loss is None and max_val_bleu is None:
                # Checkpoint carried no training stats; reset the counters so
                # the increments below do not fail on None.
                ep_offset, train_loss, min_val_loss, max_val_bleu = 0, 0.0, 1e8, 0.0
                od['Epoch'] = 0
                od['Train_loss'] = 0.0
                od['Validation_loss'] = 0.0
                od['Validation_Bleu'] = 0.0
            else:
                od['Epoch'] = ep_offset
                od['Train_loss'] = train_loss
                od['Validation_loss'] = min_val_loss
                od['Validation_Bleu'] = float(max_val_bleu)
            print_log(logger, od)
            ep_offset += 1
            max_val_bleu = float(max_val_bleu)

        train(model,
              train_dataloader,
              val_dataloader,
              test_dataloader,
              voc,
              device,
              args,
              logger,
              ep_offset,
              min_val_loss,
              max_val_bleu,
              experiment=experiment)
    else:
        if checkpoint is None:
            logger.info('Cannot decode because of absence of checkpoints')
            sys.exit()
        else:
            config_file_name = os.path.join('Models', args.run_name,
                                            'config.p')
            beam_width = args.beam_width
            gpu = args.gpu
            tree_height2 = 40
            use_glove = args.use_glove
            max_length = args.max_length
            datatype = args.datatype
            res_file = args.res_file
            dataset = args.dataset
            load_from_ep = args.load_from_ep
            run_name = args.run_name
            max_epochs = args.max_epochs
            with open(config_file_name, 'rb') as f:
                args = AttrDict(pickle.load(f))

            args.load_from_ep = load_from_ep
            args.dataset = dataset
            args.beam_width = beam_width
            args.gpu = gpu
            args.height_dec = height_dec
            args.tree_height2 = tree_height2
            args.use_glove = use_glove
            args.max_length = max_length
            args.datatype = datatype
            args.res_file = res_file
            args.run_name = run_name
            args.max_epochs = max_epochs
            if args.use_word2vec:
                logger.info(
                    'Over-writing emb_size to 300 because argument use_word2vec has been set to True'
                )
                args.emb_size = 300

            model = SYN_Par(args, voc, device, logger, experiment)
            ep_offset, train_loss, min_val_loss, max_val_bleu, voc = load_checkpoint(
                model, args.mode, checkpoint, logger, device,
                args.pretrained_encoder)
            logger.info('Decoding from')
            od = OrderedDict()

            if ep_offset is None and train_loss is None and min_val_loss is None and max_val_bleu is None:
                od['Epoch'] = 0
                od['Train_loss'] = 0.0
                od['Validation_loss'] = 0.0
                od['Validation_Bleu'] = 0.0
            else:
                od['Epoch'] = ep_offset
                od['Train_loss'] = train_loss
                od['Validation_loss'] = min_val_loss
                od['Validation_Bleu'] = float(max_val_bleu)
            print_log(logger, od)
        '''
        refs, hyps, test_loss_epoch = validation(args, model, test_dataloader, voc, device, logger)
        refs = open('data/controlledgen/test_ref.txt').read().split('\n')[:-1]
        with open('temp_refs.txt', 'w') as f:
            f.write('\n'.join(refs))
        with open('temp_hyps.txt', 'w') as f:
            f.write('\n'.join(hyps))
        bleu_score_test = run_multi_bleu('temp_hyps.txt', 'temp_refs.txt')
        #bleu_score_test = bleu_scorer(refs, hyps)
        #os.remove('temp_hyps.txt')
        #os.remove('temp_refs.txt')

        logger.info('Test BLEU score: {}'.format(bleu_score_test))
        '''
        decode_greedy(model, test_dataloader, voc, device, args, logger)