Example #1

import os
from collections import OrderedDict

import torch

# Project-specific helpers and constants (EncoderDecoder, load_ami_data,
# get_a_batch, shift_decoder_target, check_if_id_exists, diversity1_sent,
# write_attn_scores, START_TOKEN_ID, STOP_TOKEN_ID, KEYPADMASK_DTYPE) are
# assumed to be imported from the surrounding summariser codebase.

def train_v5(start_idx, end_idx):
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['num_utterances'] = 1500  # max no. utterance in a meeting
    args['num_words'] = 64  # max no. words in an utterance
    args['summary_length'] = 300  # max no. words in a summary
    args['summary_type'] = 'short'  # long or short summary
    args['vocab_size'] = 30522  # BERT tokenizer
    args['embedding_dim'] = 256  # word embedding dimension
    args['rnn_hidden_size'] = 512  # RNN hidden size

    args['dropout'] = 0.1
    args['num_layers_enc'] = 2  # in total it's num_layers_enc*2 (word/utt)
    args['num_layers_dec'] = 1

    args['batch_size'] = 1

    args['memory_utt'] = False

    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/"
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_APR16C2-ep0.pt"
    load_option = 2  # 1=old | 2=new
    # Assumed defaults: 'decoding_method' and 'teacherforcing' were
    # module-level globals in the original script.
    decoding_method = 'teacherforcing'  # 'teacherforcing' or 'beamsearch'
    teacherforcing = decoding_method == 'teacherforcing'
    # ---------------------------------------------------------------------------------- #
    # print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            print('running on the stack... 1 GPU')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    train_data = load_ami_data('test')  # note: this decoding pass runs over the 'test' split

    model = EncoderDecoder(args, device=device)
    # print(model)

    # Load model if specified (path to pytorch .pt)
    if args['load_model'] is not None:
        trained_model = args['load_model']
        if device == 'cuda':
            try:
                state = torch.load(trained_model)
                if load_option == 1:
                    model.load_state_dict(state)
                elif load_option == 2:
                    model_state_dict = state['model']
                    model.load_state_dict(model_state_dict)
            except RuntimeError:  # need to strip the 'module.' prefix left by DataParallel
                state = torch.load(trained_model)
                if load_option == 1:
                    model_state_dict = state
                else:
                    model_state_dict = state['model']
                new_model_state_dict = OrderedDict()
                for key in model_state_dict.keys():
                    new_model_state_dict[key.replace(
                        "module.", "")] = model_state_dict[key]
                model.load_state_dict(new_model_state_dict)
        else:
            try:
                state = torch.load(trained_model,
                                   map_location=torch.device('cpu'))
                if load_option == 1:
                    model.load_state_dict(state)
                elif load_option == 2:
                    model_state_dict = state['model']
                    model.load_state_dict(model_state_dict)
            except RuntimeError:  # need to strip the 'module.' prefix left by DataParallel
                state = torch.load(trained_model,
                                   map_location=torch.device('cpu'))
                if load_option == 1:
                    model_state_dict = state
                else:
                    model_state_dict = state['model']
                new_model_state_dict = OrderedDict()
                for key in model_state_dict.keys():
                    new_model_state_dict[key.replace(
                        "module.", "")] = model_state_dict[key]
                model.load_state_dict(new_model_state_dict)
        model.eval()
        print("Loaded model from {}".format(args['load_model']))
    else:
        print("Train a new model")

    print("Train a new model")

    # Hyperparameters
    BATCH_SIZE = args['batch_size']
    if BATCH_SIZE != 1:
        raise ValueError("Batch Size must be 1")

    num_train_data = len(train_data)
    num_batches = int(num_train_data / BATCH_SIZE)
    print("num_batches = {}".format(num_batches))

    idx = 0

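    # Decoding configuration. START_TOKEN_ID / STOP_TOKEN_ID come from the
    # codebase; Example #2 below uses 101 ([CLS]) and 103 ([MASK]) for the
    # same roles, with 102 ([SEP]) as the sentence separator.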
    decode_dict = {
        'batch_size': BATCH_SIZE,
        'k': 10,
        'search_method': 'argmax',
        'time_step': args['summary_length'],
        'vocab_size': args['vocab_size'],
        'device': device,
        'start_token_id': START_TOKEN_ID,
        'stop_token_id': STOP_TOKEN_ID,
        'alpha': 2.0,
        'length_offset': 5,
        'penalty_ug': 0.0,
        'keypadmask_dtype': KEYPADMASK_DTYPE,
        'memory_utt': args['memory_utt']
    }

    print("DECODING: {}".format(decoding_method))

    with torch.no_grad():
        for bn in range(start_idx, end_idx):

            if check_if_id_exists(bn): continue

            input, u_len, w_len, target, tgt_len = get_a_batch(
                train_data, bn, BATCH_SIZE, args['num_utterances'],
                args['num_words'], args['summary_length'],
                args['summary_type'], device)

            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(
                target, tgt_len, device, mask_offset=True)
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)

            if teacherforcing:
                try:
                    # decoder_output = model(input, u_len, w_len, target)
                    decoder_output, _, attn_scores, _, u_attn_scores = model(
                        input, u_len, w_len, target)
                except IndexError:
                    print(
                        "there is an IndexError --- likely from if segment_indices[bn][-1] == u_len[bn]-1:"
                    )
                    print("for now just skip this batch!")
                    idx += BATCH_SIZE  # previously I forgot to add this line!
                    continue

                output = torch.argmax(decoder_output,
                                      dim=-1).cpu().numpy().tolist()[0]
                max_l = decoder_output.size(1)

            else:
                output, attn_score, attn_score_u = model.decode_beamsearch(
                    input, u_len, w_len, decode_dict)

                # shift decoder_output by one
                output = output.tolist()
                max_l = len(output)
                u_attn_scores = attn_score_u.unsqueeze(0)

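            # Find the summary length: token id 103 is the stop symbol here,
            # and 102 ([SEP]) marks sentence boundaries in the decoded output.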
            try:
                dec_len = output.index(103)
            except ValueError:
                dec_len = max_l
            dec_sep_pos = []
            for i, v in enumerate(output):
                if i == dec_len: break
                if v == 102: dec_sep_pos.append(i)

            if len(dec_sep_pos) == 0:
                dec_sep_pos.append(max_l)

            enc_len = u_len[0]
            dec_start_pos = [0] + [x + 1 for x in dec_sep_pos[:-1]]
            this_attn = u_attn_scores[0, :dec_len, :enc_len].cpu()

            mean_div_within_sentence, mean_div_between_sentences = diversity1_sent(
                this_attn, dec_start_pos, dec_sep_pos)
            write_attn_scores(bn, mean_div_within_sentence,
                              mean_div_between_sentences)
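
A minimal driver sketch (not part of the original source; the command-line interface is an assumption) showing how train_v5 might be invoked over a range of batch indices:

if __name__ == '__main__':
    import sys
    # Hypothetical entry point: decode batches [start_idx, end_idx).
    train_v5(int(sys.argv[1]), int(sys.argv[2]))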
Example #2

import torch
from collections import OrderedDict

# EncoderDecoder and KEYPADMASK_DTYPE are assumed to be imported from the
# surrounding summariser codebase.

class HierarchicalModel(object):
    def __init__(self, model_name, model_step=None, use_gpu=False):
        # if model_name not in ["HIER", "HIERDIV", "AMI_MT_DIV", "SPOTIFY_short", "SPOTIFY_long"]:
        #     raise ValueError("model name not exist")

        if model_name == "HIER":
            self.model_path = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_CNNDM_FEB26A-ep17.pt"
            self.load_option = 1
        elif model_name == "HIERDIV":
            self.model_path = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_CNNDMDIV_APR14A-ep02.pt"
            self.load_option = 2
        elif model_name == "AMI_MT_DIV":
            self.model_path = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_APR12ALL100-ep20.pt"
            self.load_option = 1

        elif model_name == "SPOTIFY_short":
            self.model_path = "/home/alta/summary/pm574/summariser1/lib/trained_models_spotify/HGRUV5DIV_SPOTIFY_JUNE18_v4-step{}.pt".format(
                model_step)
            self.load_option = 2

        elif model_name == "SPOTIFY_long":
            self.model_path = "/home/alta/summary/pm574/summariser1/lib/trained_models_spotify/HGRUV5DIV_SPOTIFY_JUNE18_v3-step{}.pt".format(
                model_step)
            self.load_option = 2

        else:
            self.model_path = model_name
            self.load_option = 2  # new version

        args = {}
        args['vocab_size'] = 30522  # BERT tokenizer
        args['embedding_dim'] = 256  # word embedding dimension
        args['rnn_hidden_size'] = 512  # RNN hidden size
        args['dropout'] = 0.0
        args['num_layers_enc'] = 2
        args['num_layers_dec'] = 1
        args['memory_utt'] = False

        self.device = 'cuda' if use_gpu else 'cpu'
        self.model = EncoderDecoder(args, self.device)
        self.load_model()
        self.model.eval()

    def load_model(self):
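        # load_option 1: the checkpoint is a raw state_dict;
        # load_option 2: the checkpoint is a dict holding the state_dict
        # under the 'model' key.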
        if self.device == 'cuda':
            try:
                state = torch.load(self.model_path)
                if self.load_option == 1:
                    self.model.load_state_dict(state)
                elif self.load_option == 2:
                    model_state_dict = state['model']
                    self.model.load_state_dict(model_state_dict)
                print("load succesful #1")
            except RuntimeError:  # keys prefixed with 'module.' (DataParallel checkpoint)
                state = torch.load(self.model_path)
                if self.load_option == 1:
                    model_state_dict = state
                else:
                    model_state_dict = state['model']
                new_model_state_dict = OrderedDict()
                for key in model_state_dict.keys():
                    new_model_state_dict[key.replace(
                        "module.", "")] = model_state_dict[key]
                self.model.load_state_dict(new_model_state_dict)
                print("load succesful #2")
        elif self.device == 'cpu':
            try:
                state = torch.load(self.model_path,
                                   map_location=torch.device('cpu'))
                if self.load_option == 1:
                    self.model.load_state_dict(state)
                elif self.load_option == 2:
                    model_state_dict = state['model']
                    self.model.load_state_dict(model_state_dict)
                print("load succesful #3")
            except RuntimeError:  # keys prefixed with 'module.' (DataParallel checkpoint)
                state = torch.load(self.model_path,
                                   map_location=torch.device('cpu'))
                if self.load_option == 1:
                    model_state_dict = state
                else:
                    model_state_dict = state['model']

                new_model_state_dict = OrderedDict()
                for key in model_state_dict.keys():
                    new_model_state_dict[key.replace(
                        "module.", "")] = model_state_dict[key]
                self.model.load_state_dict(new_model_state_dict)
                print("load succesful #4")

    def decode(self,
               tokenizer,
               batches,
               beam_width=4,
               time_step=144,
               penalty_ug=0.0,
               alpha=1.25,
               length_offset=5):
        decode_dict = {
            'k': beam_width,
            'time_step': time_step,
            'vocab_size': 30522,
            'device': self.device,
            'start_token_id': 101,
            'stop_token_id': 103,
            'alpha': alpha,
            'length_offset': length_offset,
            'penalty_ug': penalty_ug,
            'keypadmask_dtype': KEYPADMASK_DTYPE,
            'memory_utt': False,
            'batch_size': 1
        }
        summaries = [None for _ in range(len(batches))]
        for i, batch in enumerate(batches):
            summary_id = self.beam_search(batch, decode_dict)
            sentences = tokenizer.tgtids2summary(summary_id)
            summaries[i] = " ".join(sentences)
        return summaries

    def beam_search(self, batch, decode_dict):
        input = batch.input
        u_len = batch.u_len
        w_len = batch.w_len
        with torch.no_grad():
            summary_id, _, _ = self.model.decode_beamsearch(
                input, u_len, w_len, decode_dict)
        return summary_id

    def get_utt_attn_with_ref(self, enc_batch, target, tgt_len):
        # batch_size should be 1
        with torch.no_grad():
            # Teacher Forcing
            _, _, _, _, u_attn_scores = self.model(enc_batch.input,
                                                   enc_batch.u_len,
                                                   enc_batch.w_len, target)
        N = enc_batch.u_len[0].item()
        T = tgt_len[0].item()
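        # Sum the decoder-to-utterance attention over the T reference steps
        # and renormalise, giving a distribution over the N input utterances.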
        attention = u_attn_scores[0, :T, :N].sum(
            dim=0) / u_attn_scores[0, :T, :N].sum()
        attention = attention.cpu().numpy()
        return attention

    def get_utt_attn_without_ref(self,
                                 enc_batch,
                                 beam_width=4,
                                 time_step=144,
                                 penalty_ug=0.0,
                                 alpha=1.25,
                                 length_offset=5):
        decode_dict = {
            'k': beam_width,
            'time_step': time_step,
            'vocab_size': 30522,
            'device': self.device,
            'start_token_id': 101,
            'stop_token_id': 103,
            'alpha': alpha,
            'length_offset': length_offset,
            'penalty_ug': penalty_ug,
            'keypadmask_dtype': KEYPADMASK_DTYPE,
            'memory_utt': False,
            'batch_size': 1
        }
        # batch_size should be 1
        with torch.no_grad():

            summary_ids, attn_scores, u_attn_scores = self.model.decode_beamsearch(
                enc_batch.input, enc_batch.u_len, enc_batch.w_len, decode_dict)

        N = enc_batch.u_len[0].item()
        attention = u_attn_scores[:, :N].sum(
            dim=0) / u_attn_scores[:, :N].sum()
        attention = attention.cpu().numpy()
        return attention
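
A minimal usage sketch (assumptions: bert_tokenizer provides the tgtids2summary() helper used in decode(), and each element of batches exposes .input, .u_len, and .w_len as consumed by beam_search()):

model = HierarchicalModel("HIERDIV", use_gpu=True)
summaries = model.decode(bert_tokenizer, batches, beam_width=4,
                         time_step=144, penalty_ug=0.0, alpha=1.25)
for summary in summaries:
    print(summary)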