Example #1
def train_v5():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu']        = True
    args['num_utterances'] = 1500  # max no. utterance in a meeting
    args['num_words']      = 64    # max no. words in an utterance
    args['summary_length'] = 300   # max no. words in a summary
    args['summary_type']   = 'short'   # long or short summary
    args['vocab_size']     = 30522 # BERT tokenizer
    args['embedding_dim']   = 256   # word embedding dimension
    args['rnn_hidden_size'] = 512 # RNN hidden size

    args['dropout']        = 0.1
    args['num_layers_enc'] = 2    # in total it's num_layers_enc*2 (word/utt)
    args['num_layers_dec'] = 1

    args['batch_size']      = 1
    args['update_nbatches'] = 2
    args['num_epochs']      = 20
    args['random_seed']     = 777
    args['best_val_loss']     = 1e+10
    args['val_batch_size']    = 1 # 1 for now --- evaluate ROUGE
    args['val_stop_training'] = 5

    args['lr']         = 1.0
    args['adjust_lr']  = True     # if True overwrite the learning rate above
    args['initial_lr'] = 0.01       # lr = lr_0*step^(-decay_rate)
    args['decay_rate'] = 0.5
    args['label_smoothing'] = 0.1

    args['a_da']  = 0.2
    args['a_ext'] = 0.2
    args['a_cov'] = 0.0
    args['a_div'] = 1.0

    args['memory_utt'] = False

    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/"
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_CNNDM_FEB26A-ep12-bn0" # add .pt later
    # args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_FEB28A-ep6"
    # args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5MEM_APR8A-ep1"
    # args['load_model'] = None
    args['model_name'] = 'HGRUV5_APR16H5'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ: # to run on CUED stack machine
            print('running on the stack... 1 GPU')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '0' # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    train_data = load_ami_data('train')
    valid_data = load_ami_data('valid')
    # top up the training data with 6 meetings taken from the validation set
    random.shuffle(valid_data)
    train_data.extend(valid_data[:6])
    valid_data = valid_data[6:]

    model = EncoderDecoder(args, device=device)
    print(model)
    NUM_DA_TYPES = len(DA_MAPPING)
    da_labeller = DALabeller(args['rnn_hidden_size'], NUM_DA_TYPES, device)
    print(da_labeller)
    ext_labeller = EXTLabeller(args['rnn_hidden_size'], device)
    print(ext_labeller)

    # Load model if specified (path to pytorch .pt)
    if args['load_model'] is not None:
        model_path = args['load_model'] + '.pt'
        try:
            model.load_state_dict(torch.load(model_path))
        except RuntimeError: # need to remove module
            # Main model
            model_state_dict = torch.load(model_path)
            new_model_state_dict = OrderedDict()
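            # strip the "module." prefix added by nn.DataParallel so the keys match this single-GPU model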
            for key in model_state_dict.keys():
                new_model_state_dict[key.replace("module.","")] = model_state_dict[key]

            if args['memory_utt']:
                model.load_state_dict(new_model_state_dict, strict=False)
            else:
                model.load_state_dict(new_model_state_dict)

        model.train()
        print("Loaded model from {}".format(args['load_model']))
    else:
        print("Train a new model")


    # Hyperparameters
    BATCH_SIZE = args['batch_size']
    NUM_EPOCHS = args['num_epochs']
    VAL_BATCH_SIZE = args['val_batch_size']
    VAL_STOP_TRAINING = args['val_stop_training']

    if args['label_smoothing'] > 0.0:
        criterion = LabelSmoothingLoss(num_classes=args['vocab_size'],
                        smoothing=args['label_smoothing'], reduction='none')
    else:
        criterion = nn.NLLLoss(reduction='none')

    da_criterion = nn.NLLLoss(reduction='none')
    ext_criterion = nn.BCELoss(reduction='none')

    # ONLY train the memory part #
    # for name, param in model.named_parameters():
    #     if "utt" in name:
    #         pass
    #     else:
    #         param.requires_grad = False
    # optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),lr=args['lr'],betas=(0.9,0.999),eps=1e-08,weight_decay=0)
    # -------------------------- #

    optimizer = optim.Adam(model.parameters(),lr=args['lr'],betas=(0.9,0.999),eps=1e-08,weight_decay=0)
    optimizer.zero_grad()

    # DA labeller optimiser
    da_optimizer = optim.Adam(da_labeller.parameters(),lr=args['lr'],betas=(0.9,0.999),eps=1e-08,weight_decay=0)
    da_optimizer.zero_grad()

    # extractive labeller optimiser
    ext_optimizer = optim.Adam(ext_labeller.parameters(),lr=args['lr'],betas=(0.9,0.999),eps=1e-08,weight_decay=0)
    ext_optimizer.zero_grad()

    # validation losses
    best_val_loss = args['best_val_loss']
    best_epoch    = 0
    stop_counter  = 0

    training_step = 0

    for epoch in range(NUM_EPOCHS):
        print("======================= Training epoch {} =======================".format(epoch))
        num_train_data = len(train_data)
        # num_batches = int(num_train_data/BATCH_SIZE) + 1
        num_batches = int(num_train_data/BATCH_SIZE)
        print("num_batches = {}".format(num_batches))

        print("shuffle train data")
        random.shuffle(train_data)

        idx = 0

        for bn in range(num_batches):

            input, u_len, w_len, target, tgt_len, _, dialogue_acts, extractive_label = get_a_batch(
                    train_data, idx, BATCH_SIZE,
                    args['num_utterances'], args['num_words'],
                    args['summary_length'], args['summary_type'], device)

            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device, mask_offset=True)
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)

            decoder_output, u_output, attn_scores, cov_scores, u_attn_scores = model(input, u_len, w_len, target)

            loss = criterion(decoder_output.view(-1, args['vocab_size']), decoder_target)
            loss = (loss * decoder_mask).sum() / decoder_mask.sum()  # average over non-padded summary positions

            # COVLOSS:
            # loss_cov = compute_covloss(attn_scores, cov_scores)
            # loss_cov = (loss_cov.view(-1) * decoder_mask).sum() / decoder_mask.sum()

            # Diversity Loss (4):
            intra_div, inter_div = diverisity_loss(u_attn_scores, decoder_target, u_len, tgt_len)
            if inter_div == 0:
                loss_div = 0
            else:
                loss_div = intra_div/inter_div


            # multitask(2): dialogue act prediction
            da_output = da_labeller(u_output)
            loss_utt_mask = length2mask(u_len, BATCH_SIZE, args['num_utterances'], device)
            loss_da = da_criterion(da_output.view(-1, NUM_DA_TYPES), dialogue_acts.view(-1)).view(BATCH_SIZE, -1)
            loss_da = (loss_da * loss_utt_mask).sum() / loss_utt_mask.sum()

            # multitask(3): extractive label prediction
            ext_output = ext_labeller(u_output).squeeze(-1)
            loss_ext = ext_criterion(ext_output, extractive_label)
            loss_ext = (loss_ext * loss_utt_mask).sum() / loss_utt_mask.sum()

            # total_loss = loss + args['a_da']*loss_da + args['a_ext']*loss_ext + args['a_cov']*loss_cov
            total_loss = loss + args['a_da']*loss_da + args['a_ext']*loss_ext + args['a_div']*loss_div
            # total_loss = loss + args['a_da']*loss_da + args['a_ext']*loss_ext
            # total_loss = loss + args['a_div']*loss_div

            total_loss.backward()
            # loss.backward()

            idx += BATCH_SIZE

            if bn % args['update_nbatches'] == 0:
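                # gradients have accumulated over update_nbatches batches; clip and step the three optimisers now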
                # gradient_clipping
                max_norm = 0.5
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                nn.utils.clip_grad_norm_(da_labeller.parameters(), max_norm)
                nn.utils.clip_grad_norm_(ext_labeller.parameters(), max_norm)
                # update the gradients
                if args['adjust_lr']:
                    adjust_lr(optimizer, args['initial_lr'], args['decay_rate'], training_step)
                    adjust_lr(da_optimizer, args['initial_lr'], args['decay_rate'], training_step)
                    adjust_lr(ext_optimizer, args['initial_lr'], args['decay_rate'], training_step)
                optimizer.step()
                optimizer.zero_grad()
                da_optimizer.step()
                da_optimizer.zero_grad()
                ext_optimizer.step()
                ext_optimizer.zero_grad()
                training_step += args['batch_size']*args['update_nbatches']

            if bn % 1 == 0:  # log every batch
                print("[{}] batch {}/{}: loss = {:.5f} | loss_div = {:.5f} | loss_da = {:.5f} | loss_ext = {:.5f}".
                    format(str(datetime.now()), bn, num_batches, loss, loss_div, loss_da, loss_ext))
                # print("[{}] batch {}/{}: loss = {:.5f} | loss_da = {:.5f} | loss_ext = {:.5f}".
                #     format(str(datetime.now()), bn, num_batches, loss, loss_da, loss_ext))
                # print("[{}] batch {}/{}: loss = {:.5f} | loss_div = {:.5f}".
                    # format(str(datetime.now()), bn, num_batches, loss, loss_div))
                # print("[{}] batch {}/{}: loss = {:.5f}".format(str(datetime.now()), bn, num_batches, loss))
                sys.stdout.flush()

            if bn % 10 == 0:

                print("======================== GENERATED SUMMARY ========================")
                print(bert_tokenizer.decode(torch.argmax(decoder_output[0], dim=-1).cpu().numpy()[:tgt_len[0]]))
                print("======================== REFERENCE SUMMARY ========================")
                print(bert_tokenizer.decode(decoder_target.view(BATCH_SIZE,args['summary_length'])[0,:tgt_len[0]].cpu().numpy()))

            if bn == 0: # e.g. eval every epoch
                # ---------------- Evaluate the model on validation data ---------------- #
                print("Evaluating the model at epoch {} step {}".format(epoch, bn))
                print("learning_rate = {}".format(optimizer.param_groups[0]['lr']))

                # switch to evaluation mode
                model.eval()
                da_labeller.eval()
                ext_labeller.eval()

                with torch.no_grad():
                    avg_val_loss = evaluate(model, valid_data, VAL_BATCH_SIZE, args, device, use_rouge=True)
                    # avg_val_loss = evaluate_greedy(model, valid_data, VAL_BATCH_SIZE, args, device)

                print("avg_val_loss_per_token = {}".format(avg_val_loss))

                # switch to training mode
                model.train()
                da_labeller.train()
                ext_labeller.train()
                # ------------------- Save the model OR Stop training ------------------- #
                state = {
                    'epoch': epoch, 'bn': bn,
                    'training_step': training_step,
                    'model': model.state_dict(),
                    'da_labeller': da_labeller.state_dict(),
                    'ext_labeller': ext_labeller.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_val_loss': best_val_loss
                }
                if avg_val_loss < best_val_loss:
                    stop_counter = 0
                    best_val_loss = avg_val_loss
                    best_epoch = epoch

                    savepath = args['model_save_dir']+"model-{}-ep{}.pt".format(args['model_name'], 999) # 999 = best
                    torch.save(state, savepath)
                    print("Model improved & saved at {}".format(savepath))
                else:
                    print("Model not improved #{}".format(stop_counter))
                    savepath = args['model_save_dir']+"model-{}-ep{}.pt".format(args['model_name'], 000) # 000 = current
                    torch.save(state, savepath)
                    print("Model NOT improved & saved at {}".format(savepath))
                    if stop_counter < VAL_STOP_TRAINING:
                        print("Just continue training ---- no loading old weights")
                        stop_counter += 1
                    else:
                        print("Model has not improved for {} times! Stop training.".format(VAL_STOP_TRAINING))
                        return

    print("End of training hierarchical RNN model")
Example #2
def train_mbr():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['num_utterances'] = 1200  # max no. utterance in a meeting
    args['num_words'] = 50  # max no. words in an utterance
    args['summary_length'] = 300  # max no. words in a summary
    args['summary_type'] = 'short'  # long or short summary
    args['vocab_size'] = 30522  # BERT tokenizer
    args['embedding_dim'] = 256  # word embedding dimension
    args['rnn_hidden_size'] = 512  # RNN hidden size

    args['dropout'] = 0.1
    args['num_layers_enc'] = 2  # in total it's num_layers_enc*3 (word/utt/seg)
    args['num_layers_dec'] = 1

    args['batch_size'] = 1
    args['update_nbatches'] = 1
    args['num_epochs'] = 10
    args['random_seed'] = 88
    args['best_val_loss'] = 1e+10
    args['val_batch_size'] = 1  # 1 for now --- evaluate ROUGE
    args['val_stop_training'] = 50

    args['lr'] = 1e-3
    args['adjust_lr'] = False  # if True overwrite the learning rate above
    args['initial_lr'] = 0.2  # lr = lr_0*step^(-decay_rate)
    args['decay_rate'] = 0.5

    args['a_ts'] = 0.0
    args['a_da'] = 0.0
    args['a_ext'] = 0.0
    args['a_cov'] = 0.0

    args['N'] = 10
    args['lambda'] = 0.0

    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/"
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_FEB25A-ep8"  # add .pt later
    # args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_CNNDM_FEB26A-ep8-bn2000" # add .pt later
    # args['load_model'] = None
    args['model_name'] = 'MBR_GS_FEB27C'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            print('running on the stack... 1 GPU')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    train_data = load_ami_data('train')
    valid_data = load_ami_data('valid')
    # optionally top up the training data with validation meetings (disabled here)
    # random.shuffle(valid_data)
    # train_data.extend(valid_data[:6])
    # valid_data = valid_data[6:]

    model = EncoderDecoder(args, device=device)
    print(model)

    # to use multiple GPUs
    if torch.cuda.device_count() > 1:
        print("Multiple GPUs: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    # Load model if specified (path to pytorch .pt)
    if args['load_model'] is not None:
        model_path = args['load_model'] + '.pt'
        if device == 'cuda':
            try:
                model.load_state_dict(torch.load(model_path))
            except RuntimeError:  # need to remove module
                # Main model
                model_state_dict = torch.load(model_path)
                new_model_state_dict = OrderedDict()
                for key in model_state_dict.keys():
                    new_model_state_dict[key.replace(
                        "module.", "")] = model_state_dict[key]
                model.load_state_dict(new_model_state_dict)
        else:
            model.load_state_dict(
                torch.load(model_path, map_location=torch.device('cpu')))
        model.train()
        print("Loaded model from {}".format(args['load_model']))
    else:
        print("Train a new model")

    # Hyperparameters
    BATCH_SIZE = args['batch_size']
    NUM_EPOCHS = args['num_epochs']
    VAL_BATCH_SIZE = args['val_batch_size']
    VAL_STOP_TRAINING = args['val_stop_training']

    # Cross-Entropy
    ce_criterion = nn.NLLLoss(reduction='none')

    # a single Adam optimiser over all model parameters
    optimizer = optim.Adam(model.parameters(),
                           lr=args['lr'],
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)
    optimizer.zero_grad()

    # validation losses
    best_val_loss = args['best_val_loss']
    best_epoch = 0
    stop_counter = 0

    training_step = 0

    for epoch in range(NUM_EPOCHS):
        print(
            "======================= Training epoch {} ======================="
            .format(epoch))
        num_train_data = len(train_data)
        num_batches = int(num_train_data / BATCH_SIZE)
        print("num_batches = {}".format(num_batches))

        print("shuffle train data")
        random.shuffle(train_data)

        idx = 0

        for bn in range(num_batches):

            input, u_len, w_len, target, tgt_len, _, _, _ = get_a_batch(
                train_data, idx, BATCH_SIZE, args['num_utterances'],
                args['num_words'], args['summary_length'],
                args['summary_type'], device)

            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(
                target, tgt_len, device, mask_offset=True)
            # LOSS - Minimum Bayes Risk
            printout = (bn % 1 == 0)  # always True here: print sampling details every batch
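            # grad_sampling presumably draws num_samples candidate summaries and accumulates the
            # sampling-based MBR gradient internally; R1/R2/RL appear to be returned for logging (assumption)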
            R1, R2, RL = grad_sampling(model,
                                       input,
                                       u_len,
                                       w_len,
                                       decoder_target,
                                       num_samples=args['N'],
                                       lambda1=args['lambda'],
                                       device=device,
                                       printout=printout)

            # LOSS - Cross Entropy
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)
            decoder_output, _, _, _ = model(input, u_len, w_len, target)
            loss = ce_criterion(decoder_output.view(-1, args['vocab_size']),
                                decoder_target)
            loss = (loss * decoder_mask).sum() / decoder_mask.sum()
            loss.backward()

            # R1, R2, RL = beamsearch_approx(model,
            #                     input, u_len, w_len, decoder_target,
            #                     beam_width=args['N'],
            #                     device=device, printout=printout)

            idx += BATCH_SIZE

            if bn % args['update_nbatches'] == 0:
                # update the gradients
                if args['adjust_lr']:
                    adjust_lr(optimizer, args['initial_lr'],
                              args['decay_rate'], training_step)
                optimizer.step()
                optimizer.zero_grad()
                training_step += args['batch_size'] * args['update_nbatches']

            if bn % 1 == 0:
                print(
                    "[{}] batch {}/{}: loss = {:.5f} | R1={:.2f} | R2={:.2f} | RL = {:.2f}"
                    .format(str(datetime.now()), bn, num_batches, loss,
                            R1 * 100, R2 * 100, RL * 100))
                print(
                    "-----------------------------------------------------------------------------------------"
                )
                sys.stdout.flush()

            if bn % 19 == 0:  # e.g. eval every epoch
                # ---------------- Evaluate the model on validation data ---------------- #
                print("Evaluating the model at epoch {} step {}".format(
                    epoch, bn))
                print("learning_rate = {}".format(
                    optimizer.param_groups[0]['lr']))

                # switch to evaluation mode
                model.eval()

                with torch.no_grad():
                    avg_val_loss = evaluate_greedy(model, valid_data,
                                                   VAL_BATCH_SIZE, args,
                                                   device)
                    # avg_val_loss = evaluate(model, valid_data, VAL_BATCH_SIZE, args, device, use_rouge=True)

                print("avg_val_loss_per_token = {}".format(avg_val_loss))

                # switch to training mode
                model.train()
                # ------------------- Save the model OR Stop training ------------------- #
                if avg_val_loss < best_val_loss:
                    stop_counter = 0
                    best_val_loss = avg_val_loss
                    best_epoch = epoch

                    savepath = args['model_save_dir'] + "model-{}-ep{}-bn{}.pt".format(
                        args['model_name'], epoch, bn)

                    torch.save(model.state_dict(), savepath)
                    print("Model improved & saved at {}".format(savepath))
                else:
                    print("Model not improved #{}".format(stop_counter))
                    if stop_counter < VAL_STOP_TRAINING:
                        print(
                            "Just continue training ---- no loading old weights"
                        )
                        stop_counter += 1
                    else:
                        print(
                            "Model has not improved for {} times! Stop training."
                            .format(VAL_STOP_TRAINING))
                        return

    print("End of training hierarchical RNN model")
Example #3
def train_v5_cnndm():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['num_utterances'] = 640  # max no. utterance in a meeting
    args['num_words'] = 50  # max no. words in an utterance
    args['summary_length'] = 144  # max no. words in a summary
    args['summary_type'] = 'long'  # long or short summary
    args['vocab_size'] = 30522  # BERT tokenizer
    args['embedding_dim'] = 256  # word embedding dimension
    args['rnn_hidden_size'] = 512  # RNN hidden size

    args['dropout'] = 0.1
    args['num_layers_enc'] = 2  # in total it's num_layers_enc*2 (word/utt)
    args['num_layers_dec'] = 1

    args['random_seed'] = 78
    # args['a_div']       = 1.0
    args['memory_utt'] = False

    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models_spotify/"
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_CNNDMDIV_APR14A-ep02.pt"
    args['model_name'] = 'HGRUV5DIV_SPOTIFY_JUNE18_v2'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            print('running on the stack... 1 GPU')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    # Data
    podcasts = load_podcast_data(sets=-1)
    batcher = HierBatcher(bert_tokenizer, args, podcasts, device)
    val_podcasts = load_podcast_data(sets=[10])
    val_batcher = HierBatcher(bert_tokenizer, args, val_podcasts, device)

    model = EncoderDecoder(args, device=device)
    # print(model)

    # Load model if specified (path to pytorch .pt)
    state = torch.load(args['load_model'])
    model_state_dict = state['model']
    model.load_state_dict(model_state_dict)
    print("load succesful #1")

    criterion = nn.NLLLoss(reduction='none')

    # a single Adam optimiser over all model parameters
    optimizer = optim.Adam(model.parameters(),
                           lr=2e-20,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)
    optimizer.zero_grad()

    # training schedule & validation bookkeeping
    training_step = 0
    batch_size = 4 * 4
    gradient_accum = 1
    total_step = 1000000
    valid_step = 2000
    best_val_loss = 99999999
    stop_counter = 0  # must exist before the first "not improved" check below

    # to use multiple GPUs
    if torch.cuda.device_count() > 1:
        print("Multiple GPUs: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    while training_step < total_step:

        # get a batch
        input, u_len, w_len, target, tgt_len = batcher.get_a_batch(batch_size)

        # decoder target
        decoder_target, decoder_mask = shift_decoder_target(target,
                                                            tgt_len,
                                                            device,
                                                            mask_offset=True)
        decoder_target = decoder_target.view(-1)
        decoder_mask = decoder_mask.view(-1)

        # decoder_output, _, _, _, u_attn_scores = model(input, u_len, w_len, target)
        decoder_output = model(input, u_len, w_len, target)

        loss = criterion(decoder_output.view(-1, args['vocab_size']),
                         decoder_target)
        loss = (loss * decoder_mask).sum() / decoder_mask.sum()
        loss.backward()

        # Diversity Loss:
        # if batch_size == 1:
        #     intra_div, inter_div = diverisity_loss(u_attn_scores, decoder_target, u_len, tgt_len)
        #     if inter_div == 0:
        #         loss_div = 0
        #     else:
        #         loss_div = intra_div/inter_div
        # else:
        #     dec_target_i = 0
        #     loss_div     = 0
        #     for bi in range(batch_size):
        #         one_u_attn_scores  = u_attn_scores[bi:bi+1,:,:]
        #         one_decoder_target = decoder_target[dec_target_i:dec_target_i+args['summary_length']]
        #         one_u_len   = u_len[bi:bi+1]
        #         one_tgt_len = tgt_len[bi:bi+1]
        #         intra_div, inter_div = diverisity_loss(one_u_attn_scores, one_decoder_target, one_u_len, one_tgt_len)
        #         if inter_div == 0:
        #             loss_div += 0
        #         else:
        #             loss_div += intra_div/inter_div
        #         dec_target_i += args['summary_length']
        #     loss_div /= batch_size
        #
        # total_loss = loss + args['a_div']*loss_div
        # total_loss.backward()

        if training_step % gradient_accum == 0:
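            # with gradient_accum = 1 this steps the optimiser after every batch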
            adjust_lr(optimizer, training_step)
            optimizer.step()
            optimizer.zero_grad()

        if training_step % 1 == 0:
            # print("[{}] step {}/{}: loss = {:.5f} | loss_div = {:.5f}".format(
            #     str(datetime.now()), training_step, total_step, loss, loss_div))
            print("[{}] step {}/{}: loss = {:.5f}".format(
                str(datetime.now()), training_step, total_step, loss))
            sys.stdout.flush()

        if training_step % 10 == 0:
            print(
                "======================== GENERATED SUMMARY ========================"
            )
            print(
                bert_tokenizer.decode(
                    torch.argmax(decoder_output[0],
                                 dim=-1).cpu().numpy()[:tgt_len[0]]))
            print(
                "======================== REFERENCE SUMMARY ========================"
            )
            print(
                bert_tokenizer.decode(
                    decoder_target.view(
                        batch_size,
                        args['summary_length'])[0, :tgt_len[0]].cpu().numpy()))

        if training_step % valid_step == 0:
            # ---------------- Evaluate the model on validation data ---------------- #
            print("Evaluating the model at training step {}".format(
                training_step))
            print("learning_rate = {}".format(optimizer.param_groups[0]['lr']))
            # switch to evaluation mode
            model.eval()
            with torch.no_grad():
                valid_loss = evaluate(model, val_batcher, batch_size, args,
                                      device)
            print("valid_loss = {}".format(valid_loss))
            # switch to training mode
            model.train()
            if valid_loss < best_val_loss:
                stop_counter = 0
                best_val_loss = valid_loss
                print("Model improved".format(stop_counter))
            else:
                stop_counter += 1
                print("Model not improved #{}".format(stop_counter))
                if stop_counter == 3:
                    print("Stop training!")
                    return
            state = {
                'training_step': training_step,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_val_loss': best_val_loss
            }
            savepath = args['model_save_dir'] + "{}-step{}.pt".format(
                args['model_name'], training_step)
            torch.save(state, savepath)
            print("Saved at {}".format(savepath))

        training_step += 1

    print("End of training hierarchical RNN model")
Example #4
def train_v5():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['num_utterances'] = 50  # max no. utterance in a meeting
    args['num_words'] = 32  # max no. words in an utterance
    args['summary_length'] = 144  # max no. words in a summary
    args['summary_type'] = 'long'  # long or short summary
    args['vocab_size'] = 30522  # BERT tokenizer
    args['embedding_dim'] = 256  # word embedding dimension
    args['rnn_hidden_size'] = 512  # RNN hidden size

    args['dropout'] = 0.1
    args['num_layers_enc'] = 2  # in total it's num_layers_enc*2 (word/utt)
    args['num_layers_dec'] = 1

    args['batch_size'] = 8
    args['update_nbatches'] = 1
    args['num_epochs'] = 20
    args['random_seed'] = 78
    args['best_val_loss'] = 1e+10
    args['val_batch_size'] = 64  # 1 for now --- evaluate ROUGE
    args['val_stop_training'] = 10

    args['adjust_lr'] = True  # if True overwrite the learning rate above
    args['initial_lr'] = 1e-2  # lr = lr_0*step^(-decay_rate)
    args['decay_rate'] = 0.25
    args['label_smoothing'] = 0.1

    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/"
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_CNNDM_FEB26A-ep17-bn0.pt"  # full path, .pt already included
    # args['load_model'] = None
    args['model_name'] = 'HGRUV5_CNNDM_APR1'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            print('running on the stack... 1 GPU')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '1'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    args['model_data_dir'] = "/home/alta/summary/pm574/summariser0/lib/model_data/"
    args['max_pos_embed'] = 512
    args['max_num_sentences'] = 32
    args['max_summary_length'] = args['summary_length']
    train_data = load_cnndm_data(args, 'trainx', dump=False)
    # train_data = load_cnndm_data(args, 'test', dump=False)
    # print("loaded TEST data")
    valid_data = load_cnndm_data(args, 'valid', dump=False)

    model = EncoderDecoder(args, device=device)
    print(model)

    # Load model if specified (path to pytorch .pt)
    if args['load_model'] is not None:
        model_path = args['load_model']
        try:
            model.load_state_dict(torch.load(model_path))
        except RuntimeError:  # need to remove module
            # Main model
            model_state_dict = torch.load(model_path)
            new_model_state_dict = OrderedDict()
            for key in model_state_dict.keys():
                new_model_state_dict[key.replace("module.",
                                                 "")] = model_state_dict[key]
            model.load_state_dict(new_model_state_dict)
        model.train()
        print("Loaded model from {}".format(args['load_model']))
    else:
        print("Train a new model")

    # to use multiple GPUs
    if torch.cuda.device_count() > 1:
        print("Multiple GPUs: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    print("Train a new model")

    # Hyperparameters
    BATCH_SIZE = args['batch_size']
    NUM_EPOCHS = args['num_epochs']
    VAL_BATCH_SIZE = args['val_batch_size']
    VAL_STOP_TRAINING = args['val_stop_training']

    if args['label_smoothing'] > 0.0:
        criterion = LabelSmoothingLoss(num_classes=args['vocab_size'],
                                       smoothing=args['label_smoothing'],
                                       reduction='none')
    else:
        criterion = nn.NLLLoss(reduction='none')

    # a single Adam optimiser over all model parameters
    optimizer = optim.Adam(model.parameters(),
                           lr=0.77,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)
    optimizer.zero_grad()

    # validation losses
    best_val_loss = args['best_val_loss']
    best_epoch = 0
    stop_counter = 0

    training_step = 0

    for epoch in range(NUM_EPOCHS):
        print(
            "======================= Training epoch {} ======================="
            .format(epoch))
        num_train_data = len(train_data)
        # num_batches = int(num_train_data/BATCH_SIZE) + 1
        num_batches = int(num_train_data / BATCH_SIZE)
        print("num_batches = {}".format(num_batches))

        print("shuffle train data")
        random.shuffle(train_data)

        idx = 0

        for bn in range(num_batches):

            input, u_len, w_len, target, tgt_len = get_a_batch(
                train_data, idx, BATCH_SIZE, args['num_utterances'],
                args['num_words'], args['summary_length'],
                args['summary_type'], device)

            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(
                target, tgt_len, device, mask_offset=True)
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)

            try:
                # decoder_output = model(input, u_len, w_len, target)
                decoder_output, _, attn_scores, _, u_attn_scores = model(
                    input, u_len, w_len, target)
                # import pdb; pdb.set_trace()  # leftover debugging breakpoint, disabled so training is not halted every batch
            except IndexError:
                print(
                    "there is an IndexError --- likely from if segment_indices[bn][-1] == u_len[bn]-1:"
                )
                print("for now just skip this batch!")
                idx += BATCH_SIZE  # previously I forgot to add this line!
                continue

            loss = criterion(decoder_output.view(-1, args['vocab_size']),
                             decoder_target)
            loss = (loss * decoder_mask).sum() / decoder_mask.sum()

            loss.backward()

            idx += BATCH_SIZE

            if bn % args['update_nbatches'] == 0:
                # gradient_clipping
                max_norm = 0.5
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                # update the gradients
                if args['adjust_lr']:
                    adjust_lr(optimizer, args['initial_lr'],
                              args['decay_rate'], training_step)
                optimizer.step()
                optimizer.zero_grad()
                training_step += args['batch_size'] * args['update_nbatches']

            if bn % 1 == 0:
                print("[{}] batch {}/{}: loss = {:5f}".format(
                    str(datetime.now()), bn, num_batches, loss))
                sys.stdout.flush()

            if bn % 50 == 0:

                print(
                    "======================== GENERATED SUMMARY ========================"
                )
                print(
                    bert_tokenizer.decode(
                        torch.argmax(decoder_output[0],
                                     dim=-1).cpu().numpy()[:tgt_len[0]]))
                print(
                    "======================== REFERENCE SUMMARY ========================"
                )
                print(
                    bert_tokenizer.decode(
                        decoder_target.view(BATCH_SIZE, args['summary_length'])
                        [0, :tgt_len[0]].cpu().numpy()))

            if bn % 500 == 0:
                # ---------------- Evaluate the model on validation data ---------------- #
                print("Evaluating the model at epoch {} step {}".format(
                    epoch, bn))
                print("learning_rate = {}".format(
                    optimizer.param_groups[0]['lr']))
                # switch to evaluation mode
                model.eval()

                with torch.no_grad():
                    avg_val_loss = evaluate(model, valid_data, VAL_BATCH_SIZE,
                                            args, device)

                print("avg_val_loss_per_token = {}".format(avg_val_loss))

                # switch to training mode
                model.train()
                # ------------------- Save the model OR Stop training ------------------- #
                if avg_val_loss < best_val_loss:
                    stop_counter = 0
                    best_val_loss = avg_val_loss
                    best_epoch = epoch

                    savepath = args[
                        'model_save_dir'] + "model-{}-ep{}.pt".format(
                            args['model_name'], epoch)
                    torch.save(model.state_dict(), savepath)
                    print("Model improved & saved at {}".format(savepath))
                else:
                    print("Model not improved #{}".format(stop_counter))
                    if stop_counter < VAL_STOP_TRAINING:
                        print(
                            "Just continue training ---- no loading old weights"
                        )
                        stop_counter += 1
                    else:
                        print(
                            "Model has not improved for {} times! Stop training."
                            .format(VAL_STOP_TRAINING))
                        return

    print("End of training hierarchical RNN model")