Example #1
def main(args):
    conf = getattr(configs, 'config_' + args.model)()
    # Set the random seed manually for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    else:
        print("Note that our pre-trained models require CUDA to evaluate.")

    # Load data
    test_set = APIDataset(args.data_path + 'test.desc.h5',
                          args.data_path + 'test.apiseq.h5',
                          conf['max_sent_len'])
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1)
    vocab_api = load_dict(args.data_path + 'vocab.apiseq.json')
    vocab_desc = load_dict(args.data_path + 'vocab.desc.json')
    metrics = Metrics()

    # Load model checkpoints
    model = getattr(models, args.model)(conf)
    ckpt = f'./output/{args.model}/{args.expname}/{args.timestamp}/models/model_epo{args.reload_from}.pkl'
    model.load_state_dict(torch.load(ckpt))

    f_eval = open(
        f"./output/{args.model}/{args.expname}/results.txt", "w")

    evaluate(model, metrics, test_loader, vocab_desc, vocab_api,
             args.n_samples, args.decode_mode, f_eval)
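
The args namespace consumed above comes from an argument parser that is not shown. A minimal argparse sketch covering just the attributes main() reads; flag names are taken from the code, defaults are hypothetical:

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate a pre-trained API-generation model')
    parser.add_argument('--model', type=str, default='RNNEncDec')    # hypothetical default
    parser.add_argument('--expname', type=str, default='basic')
    parser.add_argument('--timestamp', type=str, required=True)      # run folder to load from
    parser.add_argument('--reload_from', type=int, required=True)    # epoch of the checkpoint
    parser.add_argument('--data_path', type=str, default='./data/')
    parser.add_argument('--seed', type=int, default=1111)
    parser.add_argument('--n_samples', type=int, default=1)
    parser.add_argument('--decode_mode', type=str, default='sample')
    return parser.parse_args()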
Example #2
    def __init__(self, conf):
        self.conf = conf
        self.path = conf['workdir']

        self.vocab_methname = load_dict(self.path + conf['vocab_name'])
        self.vocab_apiseq = load_dict(self.path + conf['vocab_api'])
        self.vocab_tokens = load_dict(self.path + conf['vocab_tokens'])
        self.vocab_desc = load_dict(self.path + conf['vocab_desc'])

        self.codevecs = []
        self.codebase = []
        self.codebase_chunksize = 2000000

        self.valid_set = None
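
load_dict itself is not shown in any of these examples; given that it is mostly called on vocab.*.json files, a minimal sketch would be the following (Example #3 hints at a pickle-backed variant, so treat this as an assumption):

import json

def load_dict(filename):
    """Load a token-to-id vocabulary stored as a JSON object."""
    with open(filename, 'r') as f:
        return json.load(f)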
Example #3
def create_instance():
    tag_to_id = FLAGS.tag_to_id
    id_to_tag = {v: k for k, v in tag_to_id.items()}

    # build the word dictionary
    print("building dict......")
    if not isExists(FLAGS.dict_file):
        print("building dict from the training file...")
        train_file = read_conll_file(FLAGS.train_file)
        word_to_id, _ = word_mapping(train_file, FLAGS.min_freq)
        write_file(word_to_id, FLAGS.dict_file)
    else:
        print("loading dict from pickle...")
        word_to_id = load_dict(FLAGS.dict_file)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.3)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    with sess.as_default():
        print "begin for create model..."
        model = create_model(sess, word_to_id, id_to_tag) # just struct

        # load model
        model.logger.info("testing ner")
        ckpt = tf.train.get_checkpoint_state(FLAGS.model_path)
        model.logger.info("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)

    return word_to_id, tag_to_id, id_to_tag, sess, model
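
create_instance() relies on a TF 1.x-style FLAGS object; a sketch of the flag definitions it reads (paths and defaults are placeholders, and tag_to_id is presumably attached to FLAGS elsewhere in the original):

import tensorflow as tf  # TF 1.x

flags = tf.app.flags
flags.DEFINE_string('train_file', './data/train.conll', 'CoNLL-format training file')
flags.DEFINE_string('dict_file', './data/word_to_id.pkl', 'cached word dictionary')
flags.DEFINE_string('model_path', './ckpt/', 'directory containing model checkpoints')
flags.DEFINE_integer('min_freq', 1, 'minimum word frequency kept in the vocabulary')
FLAGS = flags.FLAGS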
Example #4
def train(args):
    timestamp = datetime.now().strftime('%Y%m%d%H%M')
    # LOG #
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    # alternative: format="%(asctime)s: %(name)s: %(levelname)s: %(message)s"
    tb_writer = None
    if args.visual:
        # make output directory if it doesn't already exist
        os.makedirs(f'./output/{args.model}/{args.expname}/{timestamp}/models', exist_ok=True)
        os.makedirs(f'./output/{args.model}/{args.expname}/{timestamp}/temp_results', exist_ok=True)
        # create a file handler which logs even debug messages
        fh = logging.FileHandler(f"./output/{args.model}/{args.expname}/{timestamp}/logs.txt")
        logger.addHandler(fh)  # add the handler to the logger
        tb_writer = SummaryWriter(f"./output/{args.model}/{args.expname}/{timestamp}/logs/")
        # save arguments
        json.dump(vars(args), open(f'./output/{args.model}/{args.expname}/{timestamp}/args.json', 'w'))

    # Device #
    if args.gpu_id < 0:  # use any available GPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    print(device)
    n_gpu = torch.cuda.device_count() if args.gpu_id < 0 else 1
    print(f"num of gpus: {n_gpu}")
    # Set the random seed manually for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    def save_model(model, epoch, timestamp):
        """Save model parameters to a checkpoint"""
        os.makedirs(f'./output/{args.model}/{args.expname}/{timestamp}/models', exist_ok=True)
        ckpt_path = f'./output/{args.model}/{args.expname}/{timestamp}/models/model_epo{epoch}.pkl'
        print(f'Saving model parameters to {ckpt_path}')
        torch.save(model.state_dict(), ckpt_path)

    def load_model(model, epoch, timestamp):
        """Load model parameters from a checkpoint"""
        ckpt_path = f'./output/{args.model}/{args.expname}/{timestamp}/models/model_epo{epoch}.pkl'
        print(f'Loading model parameters from {ckpt_path}')
        model.load_state_dict(torch.load(ckpt_path))

    config = getattr(configs, 'config_'+args.model)()

    ###############################################################################
    # Load dataset
    ###############################################################################
    train_set = APIDataset(args.data_path+'train.desc.h5', args.data_path+'train.apiseq.h5', config['max_sent_len'])
    valid_set = APIDataset(args.data_path+'test.desc.h5', args.data_path+'test.apiseq.h5', config['max_sent_len'])
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=config['batch_size'], shuffle=True, num_workers=1)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_set, batch_size=config['batch_size'], shuffle=True, num_workers=1)
    print("Loaded dataset!")

    ###############################################################################
    # Define the models
    ###############################################################################
    model = getattr(models, args.model)(config)
    if args.reload_from >= 0:
        load_model(model, args.reload_from, timestamp)
    model = model.to(device)

    ###############################################################################
    # Prepare the Optimizer
    ###############################################################################
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]    
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config['lr'], eps=config['adam_epsilon'])        
    scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=config['warmup_steps'],
            num_training_steps=len(train_loader)*config['epochs'])  # do not forget to update this when the dataset changes

    ###############################################################################
    # Training
    ###############################################################################
    logger.info("Training...")
    itr_global = 1
    start_epoch = 1 if args.reload_from == -1 else args.reload_from + 1
    for epoch in range(start_epoch, config['epochs']+1):

        epoch_start_time = time.time()
        itr_start_time = time.time()

        # DataLoader(shuffle=True) re-shuffles the training data at each epoch
        for batch in train_loader:  # loop through all batches in the training set
            model.train()
            batch_gpu = [tensor.to(device) for tensor in batch]
            loss = model(*batch_gpu)  
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['clip'])
            optimizer.step()
            scheduler.step()
            model.zero_grad()

            if itr_global % args.log_every == 0:
                elapsed = time.time() - itr_start_time
                log = '%s-%s|@gpu%d epo:[%d/%d] iter:%d step_time:%ds loss:%f' \
                    % (args.model, args.expname, args.gpu_id, epoch, config['epochs'], itr_global, elapsed, loss)
                if args.visual:
                    tb_writer.add_scalar('loss', loss, itr_global)
                logger.info(log)

                itr_start_time = time.time()

            if itr_global % args.valid_every == 0:
                model.eval()
                loss_records = {}

                for batch in valid_loader:
                    batch_gpu = [tensor.to(device) for tensor in batch]
                    with torch.no_grad():
                        valid_loss = model.valid(*batch_gpu)
                    for loss_name, loss_value in valid_loss.items():
                        v = loss_records.get(loss_name, [])
                        v.append(loss_value)
                        loss_records[loss_name] = v

                log = 'Validation '
                for loss_name, loss_values in loss_records.items():
                    log = log + loss_name + ': %.4f  ' % np.mean(loss_values)
                    if args.visual:
                        tb_writer.add_scalar(loss_name, np.mean(loss_values), itr_global)
                logger.info(log)

            itr_global += 1

            if itr_global % args.eval_every == 0:  # periodically evaluate on the dev set
                model.eval()
                save_model(model, itr_global, timestamp)  # checkpoint keyed by global iteration
                
                valid_loader = torch.utils.data.DataLoader(dataset=valid_set, batch_size=1, shuffle=False, num_workers=1)
                vocab_api = load_dict(args.data_path + 'vocab.apiseq.json')
                vocab_desc = load_dict(args.data_path + 'vocab.desc.json')
                metrics = Metrics()

                os.makedirs(f'./output/{args.model}/{args.expname}/{timestamp}/temp_results', exist_ok=True)
                f_eval = open(f"./output/{args.model}/{args.expname}/{timestamp}/temp_results/iter{itr_global}.txt", "w")
                repeat = 1
                decode_mode = 'sample'
                recall_bleu, prec_bleu = evaluate(model, metrics, valid_loader, vocab_desc, vocab_api, repeat, decode_mode, f_eval)

                if args.visual:
                    tb_writer.add_scalar('recall_bleu', recall_bleu, itr_global)
                    tb_writer.add_scalar('prec_bleu', prec_bleu, itr_global)
                

        # end of epoch ----------------------------
        model.adjust_lr()
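
save_model/load_model above persist only the state_dict, so the network must be re-instantiated with the same architecture before loading; a self-contained sketch of that round trip:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)                     # stand-in for getattr(models, args.model)(config)
torch.save(net.state_dict(), 'model_epo1.pkl')

net2 = nn.Linear(4, 2)                    # the same architecture, fresh weights
net2.load_state_dict(torch.load('model_epo1.pkl'))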
Example #5
if __name__ == '__main__':
    args = parse_args()
    device = torch.device(
        f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    config = getattr(configs, 'config_' + args.model)()

    ##### Define model ######
    logger.info('Constructing Model..')
    model = getattr(models, args.model)(config)  #initialize the model
    ckpt = f'./output/{args.model}/{args.dataset}/models/step{args.reload_from}.h5'
    model.load_state_dict(torch.load(ckpt, map_location=device))

    data_path = args.data_path + args.dataset + '/'

    vocab_desc = load_dict(data_path + config['vocab_desc'])
    codebase = load_codebase(data_path + config['use_codebase'],
                             args.chunk_size)
    codevecs = load_codevecs(data_path + config['use_codevecs'],
                             args.chunk_size)
    assert len(codebase)==len(codevecs), \
         "inconsistent number of chunks, check whether the specified files for codebase and code vectors are correct!"

    while True:
        try:
            query = input('Input Query: ')
            if query == "exit":
                break
            n_results = int(input('How many results? '))
        except Exception as e:
            print("Exception while parsing your input:", e)
            continue
Example #6
    def __getitem__(self, index):
        #-------------------------- negative code samples --------------------------
        #1
        rand_offsetcode = random.randint(0, self.len - 1)
        negast_num = self.data_set[rand_offsetcode]['ast_num']
        negast = read_pickle(self.ast_path + negast_num)
        negseq, negrel_par, negrel_bro, negrel_semantic, negsemantic_convert_matrix, negsemantic_mask = traverse_tree_to_generate_matrix(
            negast, self.max_ast_size, self.k, self.max_simple_name_size)
        negseq_id = [
            self.ast2id[x] if x in self.ast2id else self.ast2id['<UNK>']
            for x in negseq
        ]
        negseq_tensor = torch.LongTensor(negseq_id)
        #2
        rand_offsetcode2 = random.randint(0, self.len - 1)
        negast_num2 = self.data_set[rand_offsetcode2]['ast_num']
        negast2 = read_pickle(self.ast_path + negast_num2)
        negseq2, negrel_par2, negrel_bro2, negrel_semantic2, negsemantic_convert_matrix2, negsemantic_mask2 = traverse_tree_to_generate_matrix(
            negast2, self.max_ast_size, self.k, self.max_simple_name_size)
        negseq_id2 = [
            self.ast2id[x] if x in self.ast2id else self.ast2id['<UNK>']
            for x in negseq2
        ]
        negseq_tensor2 = torch.LongTensor(negseq_id2)
        #3
        rand_offsetcode3 = random.randint(0, self.len - 1)
        negast_num3 = self.data_set[rand_offsetcode3]['ast_num']
        negast3 = read_pickle(self.ast_path + negast_num3)
        negseq3, negrel_par3, negrel_bro3, negrel_semantic3, negsemantic_convert_matrix3, negsemantic_mask3 = traverse_tree_to_generate_matrix(
            negast3, self.max_ast_size, self.k, self.max_simple_name_size)
        negseq_id3 = [
            self.ast2id[x] if x in self.ast2id else self.ast2id['<UNK>']
            for x in negseq3
        ]
        negseq_tensor3 = torch.LongTensor(negseq_id3)

        #-------------------------- positive code sample ---------------------------
        data = self.data_set[index]
        ast_num = data['ast_num']
        nl = data['nl']
        ast = read_pickle(self.ast_path + ast_num)
        seq, rel_par, rel_bro, rel_semantic, semantic_convert_matrix, semantic_mask = traverse_tree_to_generate_matrix(
            ast, self.max_ast_size, self.k, self.max_simple_name_size)
        seq_id = [
            self.ast2id[x] if x in self.ast2id else self.ast2id['<UNK>']
            for x in seq
        ]
        #nl_id = [self.nl2id[x] if x in self.nl2id else self.nl2id['<UNK>'] for x in nl]
        # to tensor
        seq_tensor = torch.LongTensor(seq_id)
        #nl_tensor = torch.LongTensor(pad_seq(nl_id, self.max_comment_size).long())

        if self.training:
            # good desc: the description paired with this code sample
            data_path = './data/vocab.desc.json'
            vocab_desc = load_dict(data_path)  # NOTE: reloading the vocab on every call is wasteful; caching it in __init__ would avoid repeated disk reads
            nl_len = len(nl)
            for i in range(nl_len):
                nl[i] = vocab_desc.get(nl[i], 3)  # 3 is the fallback id for out-of-vocabulary words
            #nl2index, nl_len = sent2indexes(nl, vocab_desc, 30)
            nl2long = np.array(nl).astype(np.int64)
            good_desc_len = min(int(nl_len), self.max_desc_len)
            good_desc = nl2long
            good_desc = self.pad_seq(good_desc, self.max_desc_len)

            # bad desc: a randomly drawn, almost surely mismatched description
            rand_offset = random.randint(0, self.len - 1)
            bad_seq = self.data_set[rand_offset]['nl']
            bad_len = len(bad_seq)
            for i in range(bad_len):
                bad_seq[i] = vocab_desc.get(bad_seq[i], 3)
            #bad_index, bad_len = sent2indexes(bad_seq, vocab_desc, 30)
            bad2long = np.array(bad_seq).astype(np.int64)
            bad_desc_len = min(int(bad_len), self.max_desc_len)
            bad_desc = bad2long
            bad_desc = self.pad_seq(bad_desc, self.max_desc_len)

            return seq_tensor, rel_par, rel_bro, rel_semantic, negseq_tensor, negrel_par, negrel_bro, negrel_semantic, negseq_tensor2, negrel_par2, negrel_bro2, negrel_semantic2, negseq_tensor3, negrel_par3, negrel_bro3, negrel_semantic3, good_desc, good_desc_len, bad_desc, bad_desc_len
        return seq_tensor, rel_par, rel_bro, rel_semantic
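
self.pad_seq used above is not shown in this example; a minimal sketch of such a method, assuming sequences are truncated or zero-padded to a fixed length:

import numpy as np

def pad_seq(self, seq, max_len):
    """Truncate or right-pad a 1-D id sequence with zeros to exactly max_len entries."""
    padded = np.zeros(max_len, dtype=np.int64)
    n = min(len(seq), max_len)
    padded[:n] = seq[:n]
    return padded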
Example #7
def load_model(model, epoch):
    """Load parameters from checkpoint"""
    ckpt_path = './output/{}/{}/models/model_epo{}.pkl'.format(
        args.model, args.expname, epoch)
    print(f'Loading model parameters from {ckpt_path}')
    model.load_state_dict(torch.load(ckpt_path))


config = getattr(configs, 'config_' + args.model)()

###############################################################################
# Load data
###############################################################################
train_set = APIDataset(args.data_path + 'train.desc.h5',
                       args.data_path + 'train.apiseq.h5', config['maxlen'])
valid_set = APIDataset(args.data_path + 'test.desc.h5',
                       args.data_path + 'test.apiseq.h5', config['maxlen'])

vocab_api = load_dict(args.data_path + 'vocab.apiseq.json')
vocab_desc = load_dict(args.data_path + 'vocab.desc.json')
n_tokens = len(vocab_api)

metrics = Metrics()

print("Loaded data!")

###############################################################################
# Define the models
###############################################################################

model = getattr(models, args.model)(config, n_tokens)
if args.reload_from >= 0:
    load_model(model, args.reload_from)
Example #8
def train(args):
    timestamp = datetime.now().strftime('%Y%m%d%H%M')
    # LOG #
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    # alternative: format="%(asctime)s: %(name)s: %(levelname)s: %(message)s"
    tb_writer = None
    if args.visual:
        # make output directory if it doesn't already exist
        os.makedirs(f'./output/{args.model}/{args.expname}/{timestamp}/models',
                    exist_ok=True)
        os.makedirs(
            f'./output/{args.model}/{args.expname}/{timestamp}/temp_results',
            exist_ok=True)
        fh = logging.FileHandler(
            f"./output/{args.model}/{args.expname}/{timestamp}/logs.txt")
        # create file handler which logs even debug messages
        logger.addHandler(fh)  # add the handlers to the logger
        tb_writer = SummaryWriter(
            f"./output/{args.model}/{args.expname}/{timestamp}/logs/")
        # save arguments
        json.dump(
            vars(args),
            open(f'./output/{args.model}/{args.expname}/{timestamp}/args.json',
                 'w'))

    # Device #
    if args.gpu_id < 0:  # use any available GPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    print(device)
    n_gpu = torch.cuda.device_count() if args.gpu_id < 0 else 1
    print(f"num of gpus: {n_gpu}")
    # Set the random seed manually for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    def save_model(model, itr, timestamp):
        """Save model parameters to checkpoint"""
        os.makedirs(f'./output/transformer/{args.expname}/{timestamp}/models',
                    exist_ok=True)
        ckpt_path = f'./output/transformer/{args.expname}/{timestamp}/models/model_itr{itr}.pkl'
        print(f'Saving model parameters to {ckpt_path}')
        torch.save(model.state_dict(), ckpt_path)

    def load_model(model, itr, timestamp):
        """Load parameters from checkpoint"""
        ckpt_path = f'./output/transformer/{args.expname}/{timestamp}/models/model_itr{itr}.pkl'
        print(f'Loading model parameters from {ckpt_path}')
        model.load_state_dict(torch.load(ckpt_path))

    def make_mask(src_input, trg_input, device):
        pad_id = 0
        e_mask = (src_input != pad_id).unsqueeze(1)  # (B, 1, L)
        d_mask = (trg_input != pad_id).unsqueeze(1)  # (B, 1, L)
        # no-peek mask over future positions; 50 is the fixed max sequence length
        nopeak_mask = torch.ones([1, 50, 50], dtype=torch.bool)  # (1, L, L)
        nopeak_mask = torch.tril(nopeak_mask).to(device)  # lower-triangular
        d_mask = d_mask & nopeak_mask  # (B, L, L), False at padding and future positions
        return e_mask, d_mask

    config = getattr(configs, 'config_' + args.model)()

    ###############################################################################
    # Load data
    ###############################################################################
    train_set = APIDataset(args.data_path + 'train.desc.h5',
                           args.data_path + 'train.apiseq.h5',
                           config['max_sent_len'])
    valid_set = APIDataset(args.data_path + 'test.desc.h5',
                           args.data_path + 'test.apiseq.h5',
                           config['max_sent_len'])
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=config['batch_size'],
                                               shuffle=True,
                                               num_workers=1)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_set,
                                               batch_size=config['batch_size'],
                                               shuffle=True,
                                               num_workers=1)
    print("Loaded data!")

    ###############################################################################
    # Define the models
    ###############################################################################
    model = Transformer(10000, 10000).to(device)  # source/target vocab sizes hardcoded to 10000
    print(model)
    ###############################################################################
    # Prepare the Optimizer
    ###############################################################################
    no_decay = ['bias', 'LayerNorm.weight']
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = torch.nn.NLLLoss()
    ###############################################################################
    # Training
    ###############################################################################
    logger.info("Training...")
    itr_global = 1
    start_epoch = 1 if args.reload_from == -1 else args.reload_from + 1
    for epoch in range(start_epoch, config['epochs'] + 1):
        epoch_start_time = time.time()
        itr_start_time = time.time()
        # shuffle (re-define) data between epochs
        for batch in train_loader:  # loop through all batches in training data
            model.train()
            batch_gpu = [tensor.to(device) for tensor in batch]
            old_src_input, src_lens, trg_input, trg_lens = batch_gpu
            src_input = torch.zeros_like(old_src_input)
            src_input[:, :-1] = old_src_input[:, 1:50]  # drop the leading token
            trg_output = torch.zeros_like(trg_input)
            trg_output[:, :-1] = trg_input[:, 1:50]  # decoder targets: input shifted left by one
            trg_input[trg_input == 2] = 0  # replace token id 2 (presumably the end token) with pad in the decoder input
            # print("src_input:\n", src_input[0])
            # print("trg_input:\n", trg_input[0])
            # print("trg_output:\n", trg_output[0])

            e_mask, d_mask = make_mask(src_input, trg_input, device)

            output = model(src_input, trg_input, e_mask, d_mask)

            trg_output_shape = trg_output.shape
            optim.zero_grad()
            loss = criterion(
                output.view(-1, 10000),
                trg_output.view(trg_output_shape[0] * trg_output_shape[1]))
            loss.backward()
            optim.step()

            del src_input, trg_input, trg_output, e_mask, d_mask, output, old_src_input, src_lens, trg_lens
            torch.cuda.empty_cache()
            # log every args.log_every iterations
            if itr_global % args.log_every == 0:

                elapsed = time.time() - itr_start_time
                log = 'Transformer-%s|@gpu%d epo:[%d/%d] iter:%d step_time:%ds loss:%f' \
                % (args.expname, args.gpu_id, epoch, config['epochs'], itr_global, elapsed, loss)
                logger.info(log)
                itr_start_time = time.time()

            if itr_global % args.eval_every == 0:
                # evaluate bleu score
                model.eval()
                save_model(model, itr_global, timestamp)
                valid_loader = torch.utils.data.DataLoader(dataset=valid_set,
                                                           batch_size=1,
                                                           shuffle=False,
                                                           num_workers=1)
                vocab_api = load_dict(args.data_path + 'vocab.apiseq.json')
                vocab_desc = load_dict(args.data_path + 'vocab.desc.json')
                metrics = Metrics()
                os.makedirs(
                    f'./output/transformer/{args.expname}/{timestamp}/temp_results',
                    exist_ok=True)
                f_eval = open(
                    f"./output/transformer/{args.expname}/{timestamp}/temp_results/iter{itr_global}.txt",
                    "w")
                evaluate_transformer(model, metrics, valid_loader, vocab_desc,
                                     vocab_api, f_eval)

            itr_global += 1
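
The slicing at the top of the batch loop implements teacher forcing: the decoder consumes the target sequence and the loss is computed against the same sequence shifted left by one. A toy illustration of that shift, assuming pad id 0 as in make_mask:

import torch

trg_input = torch.tensor([[1, 5, 7, 2, 0]])   # e.g. start token, two tokens, end token, pad
trg_output = torch.zeros_like(trg_input)
trg_output[:, :-1] = trg_input[:, 1:]         # next-token targets: [[5, 7, 2, 0, 0]]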
Example #9
def train(args):
    timestamp = datetime.now().strftime('%Y%m%d%H%M')
    # LOG #
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    # alternative: format="%(asctime)s: %(name)s: %(levelname)s: %(message)s"
    tb_writer = None
    if args.visual:
        # make output directory if it doesn't already exist
        os.makedirs(f'./output/{args.model}/{args.expname}/{timestamp}/models',
                    exist_ok=True)
        os.makedirs(
            f'./output/{args.model}/{args.expname}/{timestamp}/temp_results',
            exist_ok=True)
        fh = logging.FileHandler(
            f"./output/{args.model}/{args.expname}/{timestamp}/logs.txt")
        # create file handler which logs even debug messages
        logger.addHandler(fh)  # add the handlers to the logger
        tb_writer = SummaryWriter(
            f"./output/{args.model}/{args.expname}/{timestamp}/logs/")
        # save arguments
        json.dump(
            vars(args),
            open(f'./output/{args.model}/{args.expname}/{timestamp}/args.json',
                 'w'))

    # Device #
    if args.gpu_id < 0:  # use any available GPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    print(device)
    n_gpu = torch.cuda.device_count() if args.gpu_id < 0 else 1
    print(f"num of gpus: {n_gpu}")
    # Set the random seed manually for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    def save_model(model, epoch, timestamp):
        """Save model parameters to checkpoint"""
        os.makedirs(f'./output/{args.model}/{args.expname}/{timestamp}/models',
                    exist_ok=True)
        ckpt_path = f'./output/{args.model}/{args.expname}/{timestamp}/models/model_epo{epoch}.pkl'
        print(f'Saving model parameters to {ckpt_path}')
        torch.save(model.state_dict(), ckpt_path)

    def load_model(model, epoch, timestamp):
        """Load parameters from checkpoint"""
        ckpt_path = f'./output/{args.model}/{args.expname}/{timestamp}/models/model_epo{epoch}.pkl'
        print(f'Loading model parameters from {ckpt_path}')
        model.load_state_dict(torch.load(ckpt_path))

    config = getattr(configs, 'config_' + args.model)()

    ###############################################################################
    # Load data
    ###############################################################################
    train_set = APIDataset(args.data_path + 'train.desc.h5',
                           args.data_path + 'train.apiseq.h5',
                           config['max_sent_len'])
    valid_set = APIDataset(args.data_path + 'test.desc.h5',
                           args.data_path + 'test.apiseq.h5',
                           config['max_sent_len'])
    print("Loaded data!")

    ###############################################################################
    # Define the models
    ###############################################################################
    model = getattr(models, args.model)(config)
    if args.reload_from >= 0:
        load_model(model, args.reload_from, timestamp)
    model = model.to(device)

    ###############################################################################
    # Training
    ###############################################################################
    logger.info("Training...")
    itr_global = 1
    start_epoch = 1 if args.reload_from == -1 else args.reload_from + 1
    for epoch in range(start_epoch, config['epochs'] + 1):

        epoch_start_time = time.time()
        itr_start_time = time.time()

        # shuffle (re-define) data between epochs
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=1)
        train_data_iter = iter(train_loader)
        n_iters = len(train_loader)

        itr = 1
        while True:  # loop through all batches in training data
            model.train()
            try:
                descs, apiseqs, desc_lens, api_lens = next(train_data_iter)
            except StopIteration:  # end of epoch
                break
            batch = [
                tensor.to(device)
                for tensor in [descs, desc_lens, apiseqs, api_lens]
            ]
            loss_AE = model.train_AE(*batch)

            if itr % args.log_every == 0:
                elapsed = time.time() - itr_start_time
                log = '%s-%s|@gpu%d epo:[%d/%d] iter:[%d/%d] step_time:%ds elapsed:%s \n                      '\
                %(args.model, args.expname, args.gpu_id, epoch, config['epochs'],
                         itr, n_iters, elapsed, timeSince(epoch_start_time,itr/n_iters))
                for loss_name, loss_value in loss_AE.items():
                    log = log + loss_name + ':%.4f ' % (loss_value)
                    if args.visual:
                        tb_writer.add_scalar(loss_name, loss_value, itr_global)
                logger.info(log)

                itr_start_time = time.time()

            if itr % args.valid_every == 0:
                valid_loader = torch.utils.data.DataLoader(
                    dataset=valid_set,
                    batch_size=config['batch_size'],
                    shuffle=True,
                    num_workers=1)
                model.eval()
                loss_records = {}

                for descs, apiseqs, desc_lens, api_lens in valid_loader:
                    batch = [
                        tensor.to(device)
                        for tensor in [descs, desc_lens, apiseqs, api_lens]
                    ]
                    valid_loss = model.valid(*batch)
                    for loss_name, loss_value in valid_loss.items():
                        v = loss_records.get(loss_name, [])
                        v.append(loss_value)
                        loss_records[loss_name] = v

                log = 'Validation '
                for loss_name, loss_values in loss_records.items():
                    log = log + loss_name + ':%.4f  ' % (np.mean(loss_values))
                    if args.visual:
                        tb_writer.add_scalar(loss_name, np.mean(loss_values),
                                             itr_global)
                logger.info(log)

            itr += 1
            itr_global += 1

            if itr_global % args.eval_every == 0:  # periodically evaluate on the dev set
                model.eval()
                save_model(model, itr_global, timestamp)  # checkpoint keyed by global iteration

                valid_loader = torch.utils.data.DataLoader(dataset=valid_set,
                                                           batch_size=1,
                                                           shuffle=False,
                                                           num_workers=1)
                vocab_api = load_dict(args.data_path + 'vocab.apiseq.json')
                vocab_desc = load_dict(args.data_path + 'vocab.desc.json')
                metrics = Metrics()

                os.makedirs(
                    f'./output/{args.model}/{args.expname}/{timestamp}/temp_results',
                    exist_ok=True)
                f_eval = open(
                    f"./output/{args.model}/{args.expname}/{timestamp}/temp_results/iter{itr_global}.txt",
                    "w")
                repeat = 1
                decode_mode = 'sample'
                recall_bleu, prec_bleu = evaluate(model, metrics, valid_loader,
                                                  vocab_desc, vocab_api,
                                                  repeat, decode_mode, f_eval)

                if args.visual:
                    tb_writer.add_scalar('recall_bleu', recall_bleu,
                                         itr_global)
                    tb_writer.add_scalar('prec_bleu', prec_bleu, itr_global)

        # end of epoch ----------------------------
        model.adjust_lr()
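
timeSince, used in the logging above, is another helper that is not included; a minimal sketch matching the call timeSince(epoch_start_time, itr/n_iters):

import time

def timeSince(since, percent):
    """Return 'elapsed (- remaining)' given a start time and the fraction of work done."""
    def as_minutes(s):
        return '%dm %ds' % (s // 60, s % 60)
    elapsed = time.time() - since
    if percent <= 0:
        return as_minutes(elapsed)
    remaining = elapsed / percent - elapsed   # simple linear ETA
    return '%s (- %s)' % (as_minutes(elapsed), as_minutes(remaining))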