Example #1
def main():

    # load the data
    #en_es_train_dev, es_en_train, es_en_dev, es_en_test, mappings, u1, c1 = get_en_es_data(0,0)

    # for more datasets, uncomment the commented get_*_data lines above and below
    #es_en_train_dev, es_en_train, es_en_dev, es_en_test, mappings, u2, c2 = get_es_en_data(0, 0)
    fr_en_train_dev, fr_en_train, fr_en_dev, fr_en_test, mappings, u3, c3 = get_fr_en_data(0, 0)


    
    # convert the data tokens to indices
    all_tokens = np.array(list(fr_en_train[0]) + list(fr_en_dev[0]) + list(fr_en_test[0]))
    token_to_idx = create_token_to_idx(all_tokens)

    # original code line: train_user_idx = prepare_data(es_en_train[0], es_en_train[1], token_to_idx, user_to_idx)
    # code now: split into processing the tokens and importing the already-processed metadata
    # the metadata part originally had a shape of (num_of_exercises, MAX_TOKEN_SIZE) but is now (num_of_exercises, MAX_TOKEN_SIZE, num_of_features)
    train_sentence_idx = process_sentence(fr_en_train_dev[0], token_to_idx)
    train_metadata = fr_en_train_dev[1]

    instance_id_to_dict = fr_en_train_dev[3]
    # convert true_labels to idx
    labels_array = np.zeros((len(fr_en_train_dev[0]), MAX_TOKEN_SIZE))
    for i in range(len(labels_array)):
        idx = np.array([instance_id_to_dict[i_id] for i_id in fr_en_train_dev[2][i]] + \
                       [0] * (MAX_TOKEN_SIZE - len(fr_en_train_dev[2][i])))
        labels_array[i] = idx

    model = EncoderDecoder(len(token_to_idx), mappings, 100, 100, 4, 100, 100)
    for j in range(10):
        print("Epoch ", j+1)
        total_loss = 0
        for i in tqdm(range(0, len(train_sentence_idx), BATCH_SIZE)):
            x_batch = train_sentence_idx[i: i+BATCH_SIZE]
            y_batch = labels_array[i: i+BATCH_SIZE]
            x_metadata_batch = train_metadata[i: i+BATCH_SIZE]

            mask = create_padding_mask(x_batch)
            with tf.GradientTape() as tape:
                logits = model.call(x_batch, x_metadata_batch, mask, training=True)
                loss = model.loss_function(logits, y_batch, mask)
                total_loss += loss.numpy()
            gradients = tape.gradient(loss, model.trainable_variables)
            model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        print("Avg batch loss", total_loss/i+1)

            # if i == 40:
            # break
        
        # print("====Dev ====")
        # flattened_instance_ids, actual, preds = predict(model, es_en_dev, token_to_idx)
        print("====Test====")
        flattened_instance_ids, actual, preds = predict(model, fr_en_test, token_to_idx)
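
Note: process_sentence, create_token_to_idx, and create_padding_mask are imported from the project's data module and are not shown here. A minimal sketch of what they plausibly do, assuming index 0 is reserved for padding and sentences are clipped or padded to MAX_TOKEN_SIZE:

import numpy as np

MAX_TOKEN_SIZE = 100  # assumption; the real constant lives in the data module

def create_token_to_idx(all_tokens):
    # Reserve 0 for padding; give every distinct token its own index.
    token_to_idx = {}
    for sent in all_tokens:
        for tok in sent:
            if tok not in token_to_idx:
                token_to_idx[tok] = len(token_to_idx) + 1
    return token_to_idx

def process_sentence(sentences, token_to_idx):
    # Build a (num_sentences, MAX_TOKEN_SIZE) matrix of token indices, 0-padded.
    out = np.zeros((len(sentences), MAX_TOKEN_SIZE), dtype=np.int32)
    for i, sent in enumerate(sentences):
        for j, tok in enumerate(sent[:MAX_TOKEN_SIZE]):
            out[i, j] = token_to_idx.get(tok, 0)
    return out

def create_padding_mask(x_batch):
    # 1.0 at real-token positions, 0.0 at padding positions.
    return (x_batch != 0).astype(np.float32)
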
Example #2
    def __init__(self, config: HiDDenConfiguration, device: torch.device):
        self.enc_dec = EncoderDecoder(config).to(device)
        self.discr = Discriminator(config).to(device)
        self.opt_enc_dec = torch.optim.Adam(self.enc_dec.parameters())
        self.opt_discr = torch.optim.Adam(self.discr.parameters())

        self.config = config
        self.device = device
        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        self.cover_label = 1
        self.encod_label = 0
Example #3
    def __init__(self, configuration: HiDDenConfiguration,
                 device: torch.device, noiser: Noiser, tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object, if specified -- enables Tensorboard logging
        """
        super(Hidden, self).__init__()

        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.discriminator = Discriminator(configuration).to(device)

        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())

        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device

        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)
        self.ce_loss = nn.CrossEntropyLoss().to(device)

        # Define the labels used for training the discriminator/adversarial loss
        self.cover_label = 1
        self.encoded_label = 0

        self.tb_logger = tb_logger
        if tb_logger is not None:
            from tensorboard_logger import TensorBoardLogger
            #print(self.encoder_decoder.encoder._modules['module'].final_layer)
            encoder_final = self.encoder_decoder.encoder._modules[
                'module'].final_layer
            encoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/encoder_out'))
            decoder_final = self.encoder_decoder.decoder._modules[
                'module'].linear
            decoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/decoder_out'))
            #print(self.discriminator._modules)
            discrim_final = self.discriminator._modules['linear']
            discrim_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/discrim_out'))
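
Note: grad_hook_by_name belongs to the project's custom TensorBoardLogger and is not shown. A minimal sketch of such a hook factory, assuming the logger wraps a tensorboardX SummaryWriter (all names here are assumptions):

class TensorBoardLoggerSketch:
    def __init__(self, writer):
        self.writer = writer  # e.g. a tensorboardX SummaryWriter
        self.step = 0

    def grad_hook_by_name(self, name):
        # weight.register_hook(fn) calls fn with the gradient of that tensor,
        # so the factory only needs to close over the tag name.
        def hook(grad):
            self.writer.add_histogram(name, grad.detach().cpu().numpy(), self.step)
        return hook
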
Example #4
def main():
    # load the es_en dataset (additional loaders can be added here)
    es_en_train_dev, es_en_train, es_en_dev, es_en_test, mappings, u2, c2 = get_es_en_data(
        0, 0)

    # convert the data tokens to indices
    all_tokens = np.array(
        list(es_en_train[0]) + list(es_en_dev[0]) + list(es_en_test[0]))
    token_to_idx = create_token_to_idx(all_tokens)

    train_sentence_idx = process_sentence(es_en_train_dev[0], token_to_idx)
    train_metadata = es_en_train_dev[1]

    instance_id_to_dict = es_en_train_dev[3]
    # convert true_labels to idx
    labels_array = np.zeros((len(es_en_train_dev[0]), MAX_TOKEN_SIZE))
    for i in range(len(labels_array)):
        idx = np.array([instance_id_to_dict[i_id] for i_id in es_en_train_dev[2][i]] + \
                       [0] * (MAX_TOKEN_SIZE - len(es_en_train_dev[2][i])))
        labels_array[i] = idx

    model = EncoderDecoder(len(token_to_idx), mappings, 100, 100, 4, 100, 100)
    for j in range(10):
        print("Epoch ", j + 1)
        total_loss = 0
        for i in tqdm(range(0, len(train_sentence_idx), BATCH_SIZE)):
            x_batch = train_sentence_idx[i:i + BATCH_SIZE]
            y_batch = labels_array[i:i + BATCH_SIZE]
            x_metadata_batch = train_metadata[i:i + BATCH_SIZE]

            mask = create_padding_mask(x_batch)
            with tf.GradientTape() as tape:
                logits = model.call(x_batch,
                                    x_metadata_batch,
                                    mask,
                                    training=True)
                loss = model.loss_function(logits, y_batch, mask)
                total_loss += loss.numpy()
            gradients = tape.gradient(loss, model.trainable_variables)
            model.optimizer.apply_gradients(
                zip(gradients, model.trainable_variables))
        print("Avg batch loss", total_loss / i + 1)

        print("====Test====")
        flattened_instance_ids, actual, preds = predict(
            model, es_en_test, token_to_idx)
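
Note: model.loss_function is not shown in these examples. A plausible reconstruction, assuming a per-token sparse cross-entropy with the padding positions masked out:

import tensorflow as tf

def masked_loss(logits, labels, mask):
    # per-token loss, shape (batch, MAX_TOKEN_SIZE)
    per_token = tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)
    mask = tf.cast(mask, per_token.dtype)
    # average only over real (unpadded) tokens
    return tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)
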
Example #5
def exec_complexity(args):
    x = torch.randn(1, args.in_channels, args.block_size,
                    args.block_size).to(args.device)
    m = torch.randn(1, args.message_length).to(args.device)
    noiser = Noiser('', device=args.device)
    if args.arch == 'hidden':
        model = EncoderDecoder(args.hidden_config, noiser).to(args.device)
    else:
        model = nets.EncoderDecoder(args.block_size, args.message_length,
                                    noiser, args.in_channels,
                                    args.layers).to(args.device)
    flops, params = thop.profile(model, inputs=(x, m))
    print(f'FLOPs = {flops / 1000 ** 3:.6f}G')
    print(f'Params = {params / 1000 ** 2:.6f}M')
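
A minimal way to drive exec_complexity from a script; the field names are exactly the ones the function reads, but the values below are placeholders:

import argparse

args = argparse.Namespace(
    in_channels=3, block_size=128, message_length=30,
    device='cpu', arch='custom', layers=4, hidden_config=None)
# exec_complexity(args)  # prints FLOPs in GFLOPs and parameters in millions
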
Example #6
def get_bleu(encoder_decoder: EncoderDecoder, data_path, model_name,
             data_type):
    # test_file = open("data/copynet_test.txt", "r", encoding='utf-8')
    test_file = open(data_path + 'copynet_' + data_type + '.txt',
                     'r',
                     encoding='utf-8')
    if data_type != 'dev':
        out_file = open("results/" + model_name.split('/')[-1] + ".txt",
                        'w',
                        encoding='utf-8')
    total_score = 0.0
    num = 0.0
    for i, row in enumerate(tqdm(test_file)):
        sql = row.split('\t')[1]
        gold_nl = row.split('\t')[0]
        predicted = encoder_decoder.get_response(sql)
        predicted = predicted.replace('<SOS>', '')
        predicted = predicted.replace('<EOS>', '')
        predicted = predicted.rstrip()
        if data_type != 'dev':
            out_file.write(predicted + "\n")

        # score = sentence_bleu([gold_nl.split()], predicted.split(), smoothing_function=SmoothingFunction().method2)
        score = sentence_bleu([gold_nl.split()], predicted.split())
        # score = sentence_bleu(ref, pred)
        total_score += score
        num += 1
        '''
        if i == 10:
            break
        '''

    # del encoder_decoder
    test_file.close()
    if data_type != 'dev':
        out_file.close()
    final_score = total_score * 100 / num
    if data_type == 'dev':
        print("DEV set")
    else:
        print("TEST set")
    print("BLEU score is " + str(final_score))
    return final_score
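
Note: without smoothing, sentence_bleu returns a (near) zero score for any hypothesis that lacks higher-order n-gram overlap, which is why the commented-out line above uses SmoothingFunction. A minimal example of the smoothed variant:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

chencherry = SmoothingFunction()
score = sentence_bleu([['the', 'cat', 'sat', 'down']], ['the', 'cat', 'sat'],
                      smoothing_function=chencherry.method2)
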
Example #7
def main():

    # load the data
    en_es_train_dev, en_es_train, en_es_dev, en_es_test, mappings1, u1, c1 = get_en_es_data(0, 0)

    # load the remaining two datasets
    es_en_train_dev, es_en_train, es_en_dev, es_en_test, mappings2, u2, c2 = get_es_en_data(u1, c1)
    fr_en_train_dev, fr_en_train, fr_en_dev, fr_en_test, mappings3, u3, c3 = get_fr_en_data(u2, c2)

    ## combine train_dev for all three datasets
    # get each dataset's attributes
    en_es_sentence, en_es_meta, en_es_inst, en_es_label = en_es_train_dev
    es_en_sentence, es_en_meta, es_en_inst, es_en_label = es_en_train_dev
    fr_en_sentence, fr_en_meta, fr_en_inst, fr_en_label = fr_en_train_dev

    #concatenate
    print(en_es_sentence.shape)
    print(es_en_sentence.shape)
    print(fr_en_sentence.shape)

    print(en_es_meta.shape)
    print(es_en_meta.shape)
    print(fr_en_meta.shape)

    print(en_es_inst.shape)
    print(es_en_inst.shape)
    print(fr_en_inst.shape)
    

    combined_sentence = np.concatenate((en_es_sentence, es_en_sentence, fr_en_sentence), axis=0)
    combined_meta = np.concatenate((en_es_meta, es_en_meta, fr_en_meta), axis=0)
    combined_inst = np.concatenate((en_es_inst, es_en_inst, fr_en_inst), axis=0)
    #combine labels
    combined_labels = copy.deepcopy(en_es_label)
    combined_labels.update(es_en_label) #add es_en to dict
    combined_labels.update(fr_en_label) #add fr_en to dict

    index = np.random.permutation(combined_sentence.shape[0])
    shuffled_combined_sentence = combined_sentence[index]
    shuffled_combined_meta = combined_meta[index]
    shuffled_combined_inst = combined_inst[index]

    combined_train_dev = (shuffled_combined_sentence, shuffled_combined_meta, shuffled_combined_inst, combined_labels)

    #combine mappings1,mappings2,mappings3
    usid1, ctid1, clt1, sessid1, fmatid1, speechid1, dep1, morph1 = mappings1
    usid2, ctid2, clt2, sessid2, fmatid2, speechid2, dep2, morph2 = mappings2
    usid3, ctid3, clt3, sessid3, fmatid3, speechid3, dep3, morph3 = mappings3
    usid = combine_dicts(usid1, usid2, usid3)
    ctid = combine_dicts(ctid1, ctid2, ctid3)
    clt = combine_dicts(clt1, clt2, clt3)
    sess = combine_dicts(sessid1, sessid2, sessid3)
    fmat = combine_dicts(fmatid1, fmatid2, fmatid3)
    speech = combine_dicts(speechid1, speechid2, speechid3)
    dep = combine_dicts(dep1, dep2, dep3)
    morph = combine_dicts(morph1, morph2, morph3)
    combined_mappings = (usid, ctid, clt, sess, fmat, speech, dep, morph)
    



    # convert the data tokens to indices
    all_tokens = np.array(list(es_en_train[0]) + list(es_en_dev[0]) + list(es_en_test[0]) + \
                          list(en_es_train[0]) + list(en_es_dev[0]) + list(en_es_test[0]) + \
                          list(fr_en_train[0]) + list(fr_en_dev[0]) + list(fr_en_test[0]))
    token_to_idx = create_token_to_idx(all_tokens)

    # TODO:
    # original code line: train_user_idx = prepare_data(es_en_train[0], es_en_train[1], token_to_idx, user_to_idx)
    # code now: split into processing the tokens and importing the already-processed metadata
    # the metadata part originally had a shape of (num_of_exercises, MAX_TOKEN_SIZE) but is now (num_of_exercises, MAX_TOKEN_SIZE, num_of_features)
    train_sentence_idx = process_sentence(combined_train_dev[0], token_to_idx)
    train_metadata = combined_train_dev[1]

    instance_id_to_dict = combined_train_dev[3]
    # convert true_labels to idx
    labels_array = np.zeros((len(combined_train_dev[0]), MAX_TOKEN_SIZE))
    for i in range(len(labels_array)):
        idx = np.array([instance_id_to_dict[i_id] for i_id in combined_train_dev[2][i]] + \
                       [0] * (MAX_TOKEN_SIZE - len(combined_train_dev[2][i])))
        labels_array[i] = idx

    model = EncoderDecoder(len(token_to_idx), combined_mappings, 300, 300, 4, 100, 100)
    for j in range(10):
        print("Epoch ", j+1)
        # TODO: shuffle training data
        total_loss = 0
        # // keeps range() happy in Python 3; the /50 subsamples the data for speed
        for i in tqdm(range(0, len(train_sentence_idx) // 50, BATCH_SIZE)):
            x_batch = train_sentence_idx[i: i+BATCH_SIZE]
            y_batch = labels_array[i: i+BATCH_SIZE]

            x_metadata_batch = train_metadata[i: i+BATCH_SIZE]

            mask = create_padding_mask(x_batch)
            with tf.GradientTape() as tape:
                logits = model.call(x_batch, x_metadata_batch, mask, training=True)
                loss = model.loss_function(logits, y_batch, mask)
                total_loss += loss.numpy()
            gradients = tape.gradient(loss, model.trainable_variables)
            model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        print("Avg batch loss", total_loss/(i+1))

            # if i == 40:
            # break
        
        # print("====Dev ====")
        # flattened_instance_ids, actual, preds = predict(model, es_en_dev, token_to_idx)
        print("====Test====")
        flattened_instance_ids1, actual1, preds1 = predict(model, es_en_test, token_to_idx)
        flattened_instance_ids2, actual2, preds2 = predict(model, en_es_test, token_to_idx)
        flattened_instance_ids3, actual3, preds3 = predict(model, fr_en_test, token_to_idx)
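
Note: combine_dicts is not defined in this example. A plausible sketch, assuming each mapping assigns consecutive indices and keys shared across datasets should keep a single index:

def combine_dicts(d1, d2, d3):
    combined = dict(d1)
    for d in (d2, d3):
        for key in d:
            if key not in combined:
                combined[key] = len(combined)
    return combined
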
Example #8
class Hidden:
    def __init__(self, configuration: HiDDenConfiguration,
                 device: torch.device, noiser: Noiser, tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object, if specified -- enables Tensorboard logging
        """
        super(Hidden, self).__init__()

        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())

        self.discriminator = Discriminator(configuration).to(device)
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())

        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device

        self.bce_with_logits_loss = nn.BCEWithLogitsLoss()
        self.mse_loss = nn.MSELoss()

        # Define the labels used for training the discriminator/adversarial loss
        self.cover_label = 1
        self.encoded_label = 0

        self.tb_logger = tb_logger
        if tb_logger is not None:
            from tensorboard_logger import TensorBoardLogger
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            encoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/encoder_out'))
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            decoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/decoder_out'))
            discrim_final = self.discriminator._modules['linear']
            discrim_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/discrim_out'))

    def train_on_batch(self, batch: list):
        """
        Trains the network on a single batch consisting of images and messages
        :param batch: batch of training data, in the form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        images, messages = batch
        batch_size = images.shape[0]
        with torch.enable_grad():
            # ---------------- Train the discriminator -----------------------------
            self.optimizer_discrim.zero_grad()
            # train on cover
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake
            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                device=self.device)
            d_on_encoded = self.discriminator(encoded_images.detach())
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encoded.backward()
            self.optimizer_discrim.step()

            # --------------Train the generator (encoder-decoder) ---------------------
            self.optimizer_enc_dec.zero_grad()
            # target label for encoded images should be 'cover', because we want to fool the discriminator
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)


            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.optimizer_enc_dec.step()

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        """
        Runs validation on a single batch of data consisting of images and messages
        :param batch: batch of validation data, in form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """

        # if TensorboardX logging is enabled, save some of the tensors.
        if self.tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            self.tb_logger.add_tensor('weights/encoder_out',
                                      encoder_final.weight)
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            self.tb_logger.add_tensor('weights/decoder_out',
                                      decoder_final.weight)
            discrim_final = self.discriminator._modules['linear']
            self.tb_logger.add_tensor('weights/discrim_out',
                                      discrim_final.weight)

        images, messages = batch
        batch_size = images.shape[0]

        with torch.no_grad():
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                device=self.device)
            d_on_encoded = self.discriminator(encoded_images)
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def to_string(self):
        return '{}\n{}'.format(str(self.encoder_decoder),
                               str(self.discriminator))
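
Portability note for this example and the other HiDDen variants above: recent PyTorch versions infer an integer dtype from torch.full((batch_size, 1), 1), while nn.BCEWithLogitsLoss expects float targets. If you hit a dtype error, pass the dtype explicitly:

import torch

batch_size, cover_label = 8, 1
d_target_label_cover = torch.full((batch_size, 1), cover_label,
                                  dtype=torch.float)
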
Example #9
def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader,
          model_name, val_data_loader: DataLoader, keep_prob,
          teacher_forcing_schedule, lr, max_length, use_decay, data_path):

    global_step = 0
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './saved/' + model_name + '/'

    gamma = 0.5 if use_decay else 1.0
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    #val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)

    best_bleu = 0.0

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        #scheduler.step()
        print('epoch %i' % (epoch), flush=True)
        print('lr: ' + str(scheduler.get_lr()))

        for batch_idx, (input_idxs, target_idxs, input_tokens,
                        target_tokens) in enumerate(tqdm(train_data_loader)):
            # input_idxs and target_idxs have dim (batch_size x max_len)
            # they are NOT sorted by length
            '''
            print(input_idxs[0])
            print(input_tokens[0])
            print(target_idxs[0])
            print(target_tokens[0])
            '''
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            input_variable = Variable(input_idxs[order, :][:, :max(lengths)])
            target_variable = Variable(target_idxs[order, :])

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(
                input_variable,
                list(sorted_lengths),
                targets=target_variable,
                keep_prob=keep_prob,
                teacher_forcing=teacher_forcing)
            batch_size = input_variable.shape[0]

            flattened_outputs = output_log_probs.view(batch_size * max_length,
                                                      -1)

            batch_loss = loss_function(flattened_outputs,
                                       target_variable.contiguous().view(-1))

            batch_loss.backward()
            optimizer.step()
            batch_outputs = trim_seqs(output_seqs)

            batch_targets = [[list(seq[seq > 0])]
                             for seq in list(to_np(target_variable))]

            #batch_bleu_score = corpus_bleu(batch_targets, batch_outputs, smoothing_function=SmoothingFunction().method2)
            batch_bleu_score = corpus_bleu(batch_targets, batch_outputs)
            '''
            if global_step < 10 or (global_step % 10 == 0 and global_step < 100) or (global_step % 100 == 0 and epoch < 2):
                input_string = "Amy, Please schedule a meeting with Marcos on Tuesday April 3rd. Adam Kleczewski"
                output_string = encoder_decoder.get_response(input_string)
                writer.add_text('schedule', output_string, global_step=global_step)

                input_string = "Amy, Please cancel this meeting. Adam Kleczewski"
                output_string = encoder_decoder.get_response(input_string)
                writer.add_text('cancel', output_string, global_step=global_step)
            '''

            if global_step % 100 == 0:

                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score,
                                  global_step)

                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag,
                                         value,
                                         global_step,
                                         bins='doane')
                    writer.add_histogram('grads/' + tag,
                                         to_np(value.grad),
                                         global_step,
                                         bins='doane')

            global_step += 1

            debug = False

            if debug and batch_idx == 5:
                break

        val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)

        writer.add_scalar('val_loss', val_loss, global_step=global_step)
        writer.add_scalar('val_bleu_score',
                          val_bleu_score,
                          global_step=global_step)

        encoder_embeddings = encoder_decoder.encoder.embedding.weight.data
        encoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(encoder_embeddings,
                             metadata=encoder_vocab,
                             global_step=0,
                             tag='encoder_embeddings')

        decoder_embeddings = encoder_decoder.decoder.embedding.weight.data
        decoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(decoder_embeddings,
                             metadata=decoder_vocab,
                             global_step=0,
                             tag='decoder_embeddings')
        '''
        input_string = "Amy, Please schedule a meeting with Marcos on Tuesday April 3rd. Adam Kleczewski"
        output_string = encoder_decoder.get_response(input_string)
        writer.add_text('schedule', output_string, global_step=global_step)

        input_string = "Amy, Please cancel this meeting. Adam Kleczewski"
        output_string = encoder_decoder.get_response(input_string)
        writer.add_text('cancel', output_string, global_step=global_step)
        '''

        calc_bleu_score = get_bleu(encoder_decoder, data_path, None, 'dev')
        print('val loss: %.5f, val BLEU score: %.5f' %
              (val_loss, calc_bleu_score),
              flush=True)
        if calc_bleu_score > best_bleu:
            print("Best BLEU score! Saving model...")
            best_bleu = calc_bleu_score
            torch.save(
                encoder_decoder, "%s%s_%i_%.3f.pt" %
                (model_path, model_name, epoch, calc_bleu_score))

        print('-' * 100, flush=True)

        scheduler.step()
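
Note: writer is a module-level object in this and the following training examples; presumably something like a tensorboardX SummaryWriter created at import time:

from tensorboardX import SummaryWriter

writer = SummaryWriter('./logs')  # the log path is an assumption
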
Example #10
def main(model_name,
         use_cuda,
         batch_size,
         teacher_forcing_schedule,
         keep_prob,
         val_size,
         lr,
         decoder_type,
         vocab_limit,
         hidden_size,
         embedding_size,
         max_length,
         data_path,
         use_decay,
         saved_model_path,
         seed=42):

    model_path = './saved/' + model_name + '/'

    # TODO: Change logging to reflect loaded parameters

    print("training %s with use_cuda=%s, batch_size=%i" %
          (model_name, use_cuda, batch_size),
          flush=True)
    print("teacher_forcing_schedule=", teacher_forcing_schedule, flush=True)
    print(
        "keep_prob=%f, val_size=%f, lr=%f, decoder_type=%s, vocab_limit=%i, hidden_size=%i, embedding_size=%i, max_length=%i, seed=%i"
        % (keep_prob, val_size, lr, decoder_type, vocab_limit, hidden_size,
           embedding_size, max_length, seed),
        flush=True)

    if os.path.isdir(model_path):

        print("loading encoder and decoder from model_path", flush=True)
        #encoder_decoder = torch.load(model_path + model_name + '.pt')
        encoder_decoder = torch.load(saved_model_path)

        print("creating training and validation datasets with saved languages",
              flush=True)
        train_dataset = SequencePairDataset(
            data_path=data_path,
            maxlen=max_length,
            lang=encoder_decoder.lang,
            use_cuda=use_cuda,
            val_size=val_size,
            use_extended_vocab=(encoder_decoder.decoder_type == 'copy'),
            data_type='train')

        val_dataset = SequencePairDataset(
            data_path=data_path,
            maxlen=max_length,
            lang=encoder_decoder.lang,
            use_cuda=use_cuda,
            val_size=val_size,
            use_extended_vocab=(encoder_decoder.decoder_type == 'copy'),
            data_type='dev')

    else:
        os.mkdir(model_path)

        print("creating training and validation datasets", flush=True)
        train_dataset = SequencePairDataset(
            data_path=data_path,
            maxlen=max_length,
            vocab_limit=vocab_limit,
            use_cuda=use_cuda,
            val_size=val_size,
            seed=seed,
            use_extended_vocab=(decoder_type == 'copy'),
            data_type='train')

        val_dataset = SequencePairDataset(
            data_path=data_path,
            maxlen=max_length,
            lang=train_dataset.lang,
            use_cuda=use_cuda,
            val_size=val_size,
            seed=seed,
            use_extended_vocab=(decoder_type == 'copy'),
            data_type='dev')

        print("creating encoder-decoder model", flush=True)
        encoder_decoder = EncoderDecoder(train_dataset.lang, max_length,
                                         hidden_size, embedding_size,
                                         decoder_type)

        torch.save(encoder_decoder, model_path + '/%s.pt' % model_name)

    if use_cuda:
        encoder_decoder = encoder_decoder.cuda()
    else:
        encoder_decoder = encoder_decoder.cpu()
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
    val_data_loader = DataLoader(val_dataset, batch_size=batch_size)

    train(encoder_decoder, train_data_loader, model_name, val_data_loader,
          keep_prob, teacher_forcing_schedule, lr,
          encoder_decoder.decoder.max_length, use_decay, data_path)
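
Note: teacher_forcing_schedule is built by the caller and only needs to yield one forcing probability per epoch. One plausible construction, decaying linearly from full teacher forcing (an assumption, not the project's documented schedule):

import numpy as np

epochs = 10
teacher_forcing_schedule = np.linspace(1.0, 0.1, epochs)  # 1.0 down to 0.1
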
Example #11
    def __init__(self, args):
        # print('__init__() :', args)

        if 'cpu' not in args:
            args['cpu'] = False
        if 'vocab' not in args:
            args['vocab'] = None
        if 'ext_features' not in args:
            args['ext_features'] = None
        if 'ext_persist_features' not in args:
            args['ext_persist_features'] = None
        if 'lemma_pos_rules' not in args:
            args['lemma_pos_rules'] = None
        
        global device
        device = torch.device('cuda' if torch.cuda.is_available() and
                              not args['cpu'] else 'cpu')

        vi = sys.version_info
        print('Python version {}.{}.{}, torch version {}'.format(vi[0], vi[1], vi[2],
                                                                 torch.__version__))
        
        # Build models
        print('Building model for device {}'.format(device.type))

        try:
            self.state = torch.load(args['model'], map_location=device)
        except AttributeError:
            print('WARNING: Old model found. '
                  'Please run model_update.py on the model before executing this script.')
            exit(1)
            
        self.params = ModelParams(self.state)
        if args['ext_features']:
            self.params.update_ext_features(args['ext_features'])
        if args['ext_persist_features']:
            self.params.update_ext_persist_features(args['ext_persist_features'])

        print('Loaded model parameters from <{}>:'.format(args['model']))
        print(self.params)

        # Load the vocabulary:
        if args['vocab'] is not None:
            # Loading vocabulary from file path supplied by the user:
            args_attr = AttributeDict(args)
            self.vocab = get_vocab(args_attr)
        elif self.params.vocab is not None:
            print('Loading vocabulary stored in the model file.')
            self.vocab = self.params.vocab
        else:
            print('ERROR: you must either load a model that contains a vocabulary or '
                  'specify a vocabulary with the --vocab option!')
            sys.exit(1)

        print('Size of the vocabulary is {}'.format(len(self.vocab)))

        ef_dims = None
        self.model = EncoderDecoder(self.params, device, len(self.vocab), self.state, ef_dims).eval()
        # print(self.params.ext_features_dim)

        self.lemma_pos_rules = {}
        self.pos_names = set()
        if args['lemma_pos_rules'] is not None:
            self.read_lemma_pos_rules(args['lemma_pos_rules'])
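
Note: AttributeDict is not shown; its use above only requires a dict whose keys are also readable as attributes. A minimal sketch:

class AttributeDict(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)
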
Example #12
class HiDDen(object):
    def __init__(self, config: HiDDenConfiguration, device: torch.device):
        self.enc_dec = EncoderDecoder(config).to(device)
        self.discr = Discriminator(config).to(device)
        self.opt_enc_dec = torch.optim.Adam(self.enc_dec.parameters())
        self.opt_discr = torch.optim.Adam(self.discr.parameters())

        self.config = config
        self.device = device
        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        self.cover_label = 1
        self.encod_label = 0

    def train_on_batch(self, batch: list):
        '''
        Trains the network on a single batch consisting of images and messages
        '''
        images, messages = batch
        batch_size = images.shape[0]
        self.enc_dec.train()
        self.discr.train()

        with torch.enable_grad():
            # ---------- Train the discriminator----------
            self.opt_discr.zero_grad()

            # train on cover
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encod_label,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake
            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images.detach())
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encod.backward()
            self.opt_discr.step()

            #---------- Train the generator----------
            self.opt_enc_dec.zero_grad()

            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.opt_enc_dec.step()

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy())) \
                      / (batch_size * messages.shape[1])

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        '''Run validation on a single batch consisting of [images, messages]'''
        images, messages = batch
        batch_size = images.shape[0]

        self.enc_dec.eval()
        self.discr.eval()

        with torch.no_grad():
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encod_label,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images)
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy()))\
                     / (batch_size * messages.shape[1])

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-err': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_enced_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def to_string(self):
        return f'{str(self.enc_dec)}\n{str(self.discr)}'
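
Usage sketch for the wrapper above; the shapes are assumptions and config must be a real HiDDenConfiguration, so the calls are left commented:

# import torch
# device = torch.device('cpu')
# model = HiDDen(config, device)
# images = torch.randn(8, 3, 128, 128, device=device)
# messages = torch.randint(0, 2, (8, 30), device=device).float()
# losses, (encoded, decoded) = model.train_on_batch([images, messages])
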
Example #13
def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader,
          model_name, val_data_loader: DataLoader, keep_prob,
          teacher_forcing_schedule, lr, max_length, device,
          test_data_loader: DataLoader):

    global_step = 0
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './model/' + model_name + '/'
    trained_model = encoder_decoder

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % epoch, flush=True)
        correct_predictions = 0.0
        all_predictions = 0.0
        for batch_idx, (input_idxs, target_idxs, input_tokens,
                        target_tokens) in enumerate(tqdm(train_data_loader)):
            # Empty the cache at each batch
            torch.cuda.empty_cache()
            # input_idxs and target_idxs have dim (batch_size x max_len)
            # they are NOT sorted by length

            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            input_variable = input_idxs[order, :][:, :max(lengths)]
            input_variable = input_variable.to(device)
            target_variable = target_idxs[order, :]
            target_variable = target_variable.to(device)

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(
                input_variable,
                list(sorted_lengths),
                targets=target_variable,
                keep_prob=keep_prob,
                teacher_forcing=teacher_forcing)

            batch_size = input_variable.shape[0]

            output_sentences = output_seqs.squeeze(2)

            flattened_outputs = output_log_probs.view(batch_size * max_length,
                                                      -1)

            batch_loss = loss_function(flattened_outputs,
                                       target_variable.contiguous().view(-1))
            batch_outputs = trim_seqs(output_seqs)

            batch_inputs = [[list(seq[seq > 0])]
                            for seq in list(to_np(input_variable))]
            batch_targets = [[list(seq[seq > 0])]
                             for seq in list(to_np(target_variable))]

            for i in range(len(batch_outputs)):
                y_i = batch_outputs[i]
                tgt_i = batch_targets[i][0]

                if y_i == tgt_i:
                    correct_predictions += 1.0

                all_predictions += 1.0

            batch_loss.backward()
            optimizer.step()

            batch_bleu_score = corpus_bleu(
                batch_targets,
                batch_outputs,
                smoothing_function=SmoothingFunction().method1)

            if global_step % 100 == 0:

                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score,
                                  global_step)

                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag,
                                         value,
                                         global_step,
                                         bins='doane')
                    writer.add_histogram('grads/' + tag,
                                         to_np(value.grad),
                                         global_step,
                                         bins='doane')

            global_step += 1

        encoder_embeddings = encoder_decoder.encoder.embedding.weight.data
        encoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(encoder_embeddings,
                             metadata=encoder_vocab,
                             global_step=0,
                             tag='encoder_embeddings')

        decoder_embeddings = encoder_decoder.decoder.embedding.weight.data
        decoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(decoder_embeddings,
                             metadata=decoder_vocab,
                             global_step=0,
                             tag='decoder_embeddings')

        print('training accuracy %.5f' %
              (100.0 * (correct_predictions / all_predictions)))
        torch.save(encoder_decoder,
                   "%s%s_%i.pt" % (model_path, model_name, epoch))
        trained_model = encoder_decoder

        print('-' * 100, flush=True)

    torch.save(encoder_decoder, "%s%s_final.pt" % (model_path, model_name))
    return trained_model
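
Note: to_np and trim_seqs are helpers imported from the project's utils. A plausible reconstruction, assuming decoded sequences are cut at the first padding index (0):

import numpy as np

def to_np(tensor):
    # Detach a torch tensor and return a plain numpy array.
    return tensor.detach().cpu().numpy()

def trim_seqs(output_seqs):
    # output_seqs has shape (batch, max_len, 1); return a list of token-index
    # lists with everything from the first 0 onwards dropped.
    trimmed = []
    for seq in to_np(output_seqs):
        toks = []
        for t in seq.reshape(-1):
            if t == 0:
                break
            toks.append(int(t))
        trimmed.append(toks)
    return trimmed
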
Example #14
def main(model_name,
         use_cuda,
         batch_size,
         teacher_forcing_schedule,
         keep_prob,
         val_size,
         lr,
         decoder_type,
         vocab_limit,
         hidden_size,
         embedding_size,
         max_length,
         train_data,
         test_data,
         device,
         seed=42):

    # TODO: Change logging to reflect loaded parameters
    model_path = './model/' + model_name
    print("training %s with use_cuda=%s, batch_size=%i" %
          (model_name, use_cuda, batch_size),
          flush=True)
    print("teacher_forcing_schedule=", teacher_forcing_schedule, flush=True)
    print(
        "keep_prob=%f, val_size=%f, lr=%f, decoder_type=%s, vocab_limit=%i, hidden_size=%i, embedding_size=%i, max_length=%i, seed=%i"
        % (keep_prob, val_size, lr, decoder_type, vocab_limit, hidden_size,
           embedding_size, max_length, seed),
        flush=True)

    glove = get_glove()

    currentDT = datetime.datetime.now()
    seeds = [8, 23, 10, 41, 32]

    all_means_seen = []
    all_means_unseen = []
    all_means_mixed = []

    print('Testing using ' + test_data)

    mixed_src, mixed_tgt = load_complete_data('twophrase_1seen1unseen_clean')
    test_src, test_tgt = load_complete_data(test_data)

    # Repeat for the listed iterations #
    for it in range(5):
        print("creating training, and validation datasets", flush=True)
        all_datasets = generateKFoldDatasets(
            train_data,
            vocab_limit=vocab_limit,
            use_extended_vocab=(decoder_type == 'copy'),
            seed=seeds[it])
        all_accuracies_seen = []
        all_accuracies_unseen = []
        all_accuracies_mixed = []

        # Repeat for all k dataset folds #
        for k in range(len(all_datasets)):
            train_dataset = all_datasets[k][0]
            val_dataset = all_datasets[k][1]

            print("creating {}th encoder-decoder model".format(k), flush=True)
            encoder_decoder = EncoderDecoder(train_dataset.lang, max_length,
                                             hidden_size, embedding_size,
                                             decoder_type, device, glove)

            test_dataset = SequencePairDataset(
                test_src,
                test_tgt,
                lang=train_dataset.lang,
                use_extended_vocab=(encoder_decoder.decoder_type == 'copy'))

            mixed_dataset = SequencePairDataset(
                mixed_src,
                mixed_tgt,
                lang=train_dataset.lang,
                use_extended_vocab=(encoder_decoder.decoder_type == 'copy'))

            encoder_decoder = encoder_decoder.to(device)

            train_data_loader = DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=12)
            val_data_loader = DataLoader(val_dataset, batch_size=batch_size)
            mixed_data_loader = DataLoader(mixed_dataset,
                                           batch_size=batch_size)
            test_data_loader = DataLoader(test_dataset, batch_size=batch_size)

            # Make a dir to hold this fold's model
            os.mkdir(model_path + str(k) + str(it) + '/')
            # Train the model
            train(encoder_decoder, train_data_loader,
                  model_name + str(k) + str(it), val_data_loader, keep_prob,
                  teacher_forcing_schedule, lr,
                  encoder_decoder.decoder.max_length, device, test_data_loader)

            # Change the model path so the proper model can be loaded
            model_path = './model/' + model_name + str(k) + str(it) + '/'
            # Load the trained model before testing, just in case
            trained_model = torch.load(model_path + model_name +
                                       '{}{}_final.pt'.format(k, it))

            ## WRITE TO ALL THE LOG FILES ##
            s_f = open(
                "./logs/log_" + model_name + "seen" +
                currentDT.strftime("%Y%m%d%H%M%S") + ".txt", "w")
            s_f.write("TRAINING MODEL {}\nUSING SEED VALUE {}\n\n".format(
                model_name, seeds[it]))
            sc_f = open(
                "./logs/log_" + model_name + "seencorrect" +
                currentDT.strftime("%Y%m%d%H%M%S") + ".txt", "w")
            with torch.no_grad():
                seen_accuracy = test(trained_model,
                                     val_data_loader,
                                     encoder_decoder.decoder.max_length,
                                     device,
                                     log_files=(s_f, sc_f))
            s_f.close()
            sc_f.close()

            m_f = open(
                "./logs/log_" + model_name + "1seen1unseen" +
                currentDT.strftime("%Y%m%d%H%M%S") + ".txt", "w")
            m_f.write("TRAINING MODEL {}\nUSING SEED VALUE {}\n\n".format(
                model_name, seeds[it]))
            mc_f = open(
                "./logs/log_" + model_name + "1seen1unseencorrect" +
                currentDT.strftime("%Y%m%d%H%M%S") + ".txt", "w")
            with torch.no_grad():
                mixed_accuracy = test(trained_model,
                                      mixed_data_loader,
                                      encoder_decoder.decoder.max_length,
                                      device,
                                      log_files=(m_f, mc_f))
            m_f.close()
            mc_f.close()

            u_f = open(
                "./logs/log_" + model_name + "unseen" +
                currentDT.strftime("%Y%m%d%H%M%S") + ".txt", "w")
            u_f.write("TRAINING MODEL {}\nUSING SEED VALUE {}\n\n".format(
                model_name, seeds[it]))
            uc_f = open(
                "./logs/log_" + model_name + "unseencorrect" +
                currentDT.strftime("%Y%m%d%H%M%S") + ".txt", "w")
            with torch.no_grad():
                unseen_accuracy = test(trained_model,
                                       test_data_loader,
                                       encoder_decoder.decoder.max_length,
                                       device,
                                       log_files=(u_f, uc_f))
            u_f.close()
            uc_f.close()

            # Append the accuracies for this model so they can be averaged
            all_accuracies_seen.append(seen_accuracy)
            all_accuracies_unseen.append(unseen_accuracy)
            all_accuracies_mixed.append(mixed_accuracy)

            # Reset model path
            model_path = './model/' + model_name

        # Average all the accuracies for this round of cross val and record the mean
        s_mean = sum(all_accuracies_seen) / len(all_accuracies_seen)
        m_mean = sum(all_accuracies_mixed) / len(all_accuracies_mixed)
        u_mean = sum(all_accuracies_unseen) / len(all_accuracies_unseen)
        all_means_seen.append(s_mean)
        all_means_mixed.append(m_mean)
        all_means_unseen.append(u_mean)

    # PRINT FINAL CROSSVAL RESULTS #
    currentDT = datetime.datetime.now()
    acc_f = open(
        "./results/results_" + model_name +
        currentDT.strftime("%Y%m%d%H%M%S") + ".txt", "w")

    # Seen accuracies
    acc_f.write("SEEN ACCURACIES:\n")
    for acc in all_means_seen:
        acc_f.write("{}\n".format(acc))

    s_mean = sum(all_means_seen) / len(all_means_seen)
    s_std_dev = math.sqrt(
        sum([math.pow(x - s_mean, 2)
             for x in all_means_seen]) / len(all_means_seen))
    acc_f.write("\nMean: {}\nStandard Deviation: {}\n".format(
        s_mean, s_std_dev))

    # One seen one unseen accuracies
    acc_f.write("ONE SEEN ONE UNSEEN ACCURACIES:\n")
    for acc in all_means_mixed:
        acc_f.write("{}\n".format(acc))

    m_mean = sum(all_means_mixed) / len(all_means_mixed)
    m_std_dev = math.sqrt(
        sum([math.pow(x - m_mean, 2)
             for x in all_means_mixed]) / len(all_means_mixed))
    acc_f.write("\nMean: {}\nStandard Deviation: {}\n".format(
        m_mean, m_std_dev))

    # Unseen accuracies
    acc_f.write("\nUNSEEN ACCURACIES:\n")
    for acc in all_means_unseen:
        acc_f.write("{}\n".format(acc))

    u_mean = sum(all_means_unseen) / len(all_means_unseen)
    u_std_dev = math.sqrt(
        sum([math.pow(x - u_mean, 2)
             for x in all_means_unseen]) / len(all_means_unseen))
    acc_f.write("\nMean: {}\nStandard Deviation: {}\n".format(
        u_mean, u_std_dev))
    acc_f.close()
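
Note: the mean/standard-deviation computation above is the population formula; the standard library reproduces it directly:

import statistics

all_means_seen = [0.81, 0.79, 0.83]  # illustrative values
s_mean = statistics.mean(all_means_seen)
s_std_dev = statistics.pstdev(all_means_seen)
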
Example #15
def main(args):
    if args.model_name is not None:
        print('Preparing to train model: {}'.format(args.model_name))

    global device
    device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')

    sc_will_happen = args.self_critical_from_epoch != -1

    if args.validate is None and args.lr_scheduler == 'ReduceLROnPlateau':
        print(
            'ERROR: you need to enable validation in order to use default lr_scheduler (ReduceLROnPlateau)'
        )
        print('Hint: use something like --validate=coco:val2017')
        sys.exit(1)

    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        # transforms.Resize((256, 256)),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
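    # The Normalize() mean/std values above are the standard ImageNet
    # statistics that torchvision's pretrained ResNets expect.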

    scorers = {}
    if args.validation_scoring is not None or sc_will_happen:
        assert not (
            args.validation_scoring is None and sc_will_happen
        ), "Please provide a metric when using self-critical training"
        for s in args.validation_scoring.split(','):
            s = s.lower().strip()
            if s == 'cider':
                from eval.cider import Cider
                scorers['CIDEr'] = Cider()
            if s == 'ciderd':
                from eval.ciderD.ciderD import CiderD
                scorers['CIDEr-D'] = CiderD(df=args.cached_words)

    ########################
    # Set Model parameters #
    ########################

    # Store parameters gotten from arguments separately:
    arg_params = ModelParams.fromargs(args)

    print("Model parameters inferred from command arguments: ")
    print(arg_params)
    start_epoch = 0

    ###############################
    # Load existing model state   #
    # and update Model parameters #
    ###############################

    state = None

    if args.load_model:
        try:
            state = torch.load(args.load_model, map_location=device)
        except AttributeError:
            print(
                'WARNING: Old model found. Please run model_update.py on the '
                'model before executing this script.'
            )
            exit(1)
        new_external_features = arg_params.features.external

        params = ModelParams(state, arg_params=arg_params)
        if len(new_external_features) and \
           params.features.external != new_external_features:
            print('WARNING: external features changed: ',
                  params.features.external, new_external_features)
            print('Updating feature paths...')
            params.update_ext_features(new_external_features)
        start_epoch = state['epoch']
        print('Loaded model {} at epoch {}'.format(args.load_model,
                                                   start_epoch))
    else:
        params = arg_params
        params.command_history = []

    if params.rnn_hidden_init == 'from_features' and params.skip_start_token:
        print(
            "ERROR: Please remove --skip_start_token if you want to use image features "
            "to initialize the hidden and cell states. The <start> token is needed to "
            "trigger sequence generation, since we don't have an image feature "
            "embedding as the first input token.")
        sys.exit(1)

    # Force set the following hierarchical model parameters every time:
    if arg_params.hierarchical_model:
        params.hierarchical_model = True
        params.max_sentences = arg_params.max_sentences
        params.weight_sentence_loss = arg_params.weight_sentence_loss
        params.weight_word_loss = arg_params.weight_word_loss
        params.dropout_stopping = arg_params.dropout_stopping
        params.dropout_fc = arg_params.dropout_fc
        params.coherent_sentences = arg_params.coherent_sentences
        params.coupling_alpha = arg_params.coupling_alpha
        params.coupling_beta = arg_params.coupling_beta

    assert args.replace or \
        not os.path.isdir(os.path.join(args.output_root, args.model_path, get_model_name(args, params))) or \
        not (args.load_model and not args.validate_only), \
        '{} already exists. If you want to replace it or resume training, please use --replace flag. ' \
        'If you want to validate a loaded model without training it, use --validate_only flag. ' \
        'Otherwise specify a different model name using --model_name flag.' \
        .format(os.path.join(args.output_root, args.model_path, get_model_name(args, params)))

    if args.load_model:
        print("Final model parameters (loaded model + command arguments): ")
        print(params)

    ##############################
    # Load dataset configuration #
    ##############################

    dataset_configs = DatasetParams(args.dataset_config_file)

    if args.dataset is None and not args.validate_only:
        print('ERROR: No dataset selected!')
        print(
            'Please supply a training dataset with the argument --dataset DATASET'
        )
        print('The following datasets are configured in {}:'.format(
            args.dataset_config_file))
        for ds, _ in dataset_configs.config.items():
            if ds not in ('DEFAULT', 'generic'):
                print(' ', ds)
        sys.exit(1)

    if args.validate_only:
        if args.load_model is None:
            print(
                'ERROR: for --validate_only you need to specify a model to evaluate using --load_model MODEL'
            )
            sys.exit(1)
    else:
        dataset_params = dataset_configs.get_params(args.dataset)

        for i in dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    if args.validate is not None:
        validation_dataset_params = dataset_configs.get_params(args.validate)
        for i in validation_dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    #######################
    # Load the vocabulary #
    #######################

    # For pre-trained models attempt to obtain
    # saved vocabulary from the model itself:
    if args.load_model and params.vocab is not None:
        print("Loading vocabulary from the model file:")
        vocab = params.vocab
    else:
        if args.vocab is None:
            print(
                "ERROR: You must specify the vocabulary to be used for training using "
                "--vocab flag.\nTry --vocab AUTO if you want the vocabulary to be "
                "either generated from the training dataset or loaded from cache."
            )
            sys.exit(1)
        print("Loading / generating vocabulary:")
        vocab = get_vocab(args, dataset_params)

    print('Size of the vocabulary is {}'.format(len(vocab)))

    ##########################
    # Initialize data loader #
    ##########################

    ext_feature_sets = [
        params.features.external, params.persist_features.external
    ]
    if not args.validate_only:
        print('Loading dataset: {} with {} workers'.format(
            args.dataset, args.num_workers))
        if params.skip_start_token:
            print("Skipping the use of <start> token...")
        data_loader, ef_dims = get_loader(
            dataset_params,
            vocab,
            transform,
            args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose,
            unique_ids=sc_will_happen)
        if sc_will_happen:
            gts_sc = get_ground_truth_captions(data_loader.dataset)

    gts_sc_valid = None
    if args.validate is not None:
        valid_loader, ef_dims = get_loader(
            validation_dataset_params,
            vocab,
            transform,
            args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose)
        gts_sc_valid = get_ground_truth_captions(
            valid_loader.dataset) if sc_will_happen else None

    #########################################
    # Setup (optional) TensorBoardX logging #
    #########################################

    writer = None
    if args.tensorboard:
        if SummaryWriter is not None:
            model_name = get_model_name(args, params)
            timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
            log_dir = os.path.join(
                args.output_root, 'log_tb/{}_{}'.format(model_name, timestamp))
            writer = SummaryWriter(log_dir=log_dir)
            print("INFO: Logging TensorBoardX events to {}".format(log_dir))
        else:
            print(
                "WARNING: SummaryWriter object not available. "
                "Hint: Please install TensorBoardX using pip install tensorboardx"
            )

    ######################
    # Build the model(s) #
    ######################

    # Set per parameter learning rate here, if supplied by the user:

    if args.lr_word_decoder is not None:
        if not params.hierarchical_model:
            print(
                "ERROR: Setting the word decoder learning rate is currently "
                "supported only in the hierarchical model."
            )
            sys.exit(1)

        lr_dict = {'word_decoder': args.lr_word_decoder}
    else:
        lr_dict = {}

    model = EncoderDecoder(params,
                           device,
                           len(vocab),
                           state,
                           ef_dims,
                           lr_dict=lr_dict)

    ######################
    # Optimizer and loss #
    ######################

    sc_activated = False
    opt_params = model.get_opt_params()

    # Loss and optimizer
    if params.hierarchical_model:
        criterion = HierarchicalXEntropyLoss(
            weight_sentence_loss=params.weight_sentence_loss,
            weight_word_loss=params.weight_word_loss)
    elif args.share_embedding_weights:
        criterion = SharedEmbeddingXentropyLoss(param_lambda=0.15)
    else:
        criterion = nn.CrossEntropyLoss()

    if sc_will_happen:  # save it for later
        if args.self_critical_loss == 'sc':
            from model.loss import SelfCriticalLoss
            rl_criterion = SelfCriticalLoss()
        elif args.self_critical_loss == 'sc_with_diversity':
            from model.loss import SelfCriticalWithDiversityLoss
            rl_criterion = SelfCriticalWithDiversityLoss()
        elif args.self_critical_loss == 'sc_with_relative_diversity':
            from model.loss import SelfCriticalWithRelativeDiversityLoss
            rl_criterion = SelfCriticalWithRelativeDiversityLoss()
        elif args.self_critical_loss == 'sc_with_bleu_diversity':
            from model.loss import SelfCriticalWithBLEUDiversityLoss
            rl_criterion = SelfCriticalWithBLEUDiversityLoss()
        elif args.self_critical_loss == 'sc_with_repetition':
            from model.loss import SelfCriticalWithRepetitionLoss
            rl_criterion = SelfCriticalWithRepetitionLoss()
        elif args.self_critical_loss == 'mixed':
            from model.loss import MixedLoss
            rl_criterion = MixedLoss()
        elif args.self_critical_loss == 'mixed_with_face':
            from model.loss import MixedWithFACELoss
            rl_criterion = MixedWithFACELoss(vocab_size=len(vocab))
        elif args.self_critical_loss in [
                'sc_with_penalty', 'sc_with_penalty_throughout',
                'sc_masked_tokens'
        ]:
            raise ValueError('Deprecated loss, use \'sc\' loss')
        else:
            raise ValueError('Invalid self-critical loss')

        print('Selected self-critical loss is', rl_criterion)

        if start_epoch >= args.self_critical_from_epoch:
            criterion = rl_criterion
            sc_activated = True
            print('Self-critical loss training begins')
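    # For orientation: all the self-critical variants above build on the same
    # policy-gradient core, REINFORCE with a greedy-decoding baseline. A
    # minimal sketch (illustrative only; assumes one summed log-probability
    # and one scalar reward tensor per sampled sequence, unlike the richer
    # project loss classes imported above):
    def _scst_core_sketch(sampled_log_probs, sampled_reward, greedy_reward):
        # advantage > 0 when sampling beats greedy decoding; gradients flow
        # only through the log-probabilities.
        advantage = (sampled_reward - greedy_reward).detach()
        return -(advantage * sampled_log_probs).mean()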

    # When using CyclicalLR, the default learning rate should always be 1.0
    if args.lr_scheduler == 'CyclicalLR':
        default_lr = 1.
    else:
        default_lr = 0.001

    if sc_activated:
        optimizer = torch.optim.Adam(
            opt_params,
            lr=args.learning_rate if args.learning_rate else 5e-5,
            weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(opt_params,
                                     lr=default_lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(opt_params,
                                        lr=default_lr,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(opt_params,
                                    lr=default_lr,
                                    weight_decay=args.weight_decay)
    else:
        print('ERROR: unknown optimizer:', args.optimizer)
        sys.exit(1)

    # We don't want to initialize the optimizer if we are transferring
    # the language model from the regular model to the hierarchical model
    transfer_language_model = False

    if arg_params.hierarchical_model and state and not state.get(
            'hierarchical_model'):
        transfer_language_model = True

    # Set optimizer state to the one found in a loaded model, unless
    # we are doing a transfer learning step from flat to hierarchical model,
    # or we are using self-critical loss,
    # or the number of unique parameter groups has changed, or the user
    # has explicitly told us *not to* reuse optimizer parameters from before
    if state and not transfer_language_model and not sc_activated and not args.optimizer_reset:
        # Check that number of parameter groups is the same
        if len(optimizer.param_groups) == len(
                state['optimizer']['param_groups']):
            optimizer.load_state_dict(state['optimizer'])

    # Override lr if set explicitly in arguments:
    # 1) Global learning rate:
    if args.learning_rate:
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.learning_rate
        params.learning_rate = args.learning_rate
    else:
        params.learning_rate = default_lr

    # 2) Parameter-group specific learning rate:
    if args.lr_word_decoder is not None:
        # We want to give user an option to set learning rate for word_decoder
        # separately. Other exceptions can be added as needed:
        for param_group in optimizer.param_groups:
            if param_group.get('name') == 'word_decoder':
                param_group['lr'] = args.lr_word_decoder
                break

    if args.validate is not None and args.lr_scheduler == 'ReduceLROnPlateau':
        print('Using ReduceLROnPlateau learning rate scheduler')
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               'min',
                                                               verbose=True,
                                                               patience=2)
    elif args.lr_scheduler == 'StepLR':
        print('Using StepLR learning rate scheduler with step_size {}'.format(
            args.lr_step_size))
        # Decrease the learning rate by the factor of gamma at every
        # step_size epochs (for example every 5 or 10 epochs):
        step_size = args.lr_step_size
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size,
                                                    gamma=0.5,
                                                    last_epoch=-1)
    elif args.lr_scheduler == 'CyclicalLR':
        print(
            "Using Cyclical learning rate scheduler, lr range: [{},{}]".format(
                args.lr_cyclical_min, args.lr_cyclical_max))

        step_size = len(data_loader)
        clr = cyclical_lr(step_size,
                          min_lr=args.lr_cyclical_min,
                          max_lr=args.lr_cyclical_max)
        n_groups = len(optimizer.param_groups)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      [clr] * n_groups)
    elif args.lr_scheduler is not None:
        print('ERROR: Invalid learning rate scheduler specified: {}'.format(
            args.lr_scheduler))
        sys.exit(1)
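    # `cyclical_lr` is a project utility; a plausible triangular schedule
    # (a sketch, not necessarily the project's exact implementation) returns
    # an absolute learning rate from the LambdaLR multiplier, which is why
    # the base lr is forced to 1.0 for CyclicalLR above:
    def _triangular_clr_sketch(step_size, min_lr, max_lr):
        def lr_lambda(it):
            cycle = 1 + it // (2 * step_size)
            x = abs(it / step_size - 2 * cycle + 1)
            return min_lr + (max_lr - min_lr) * max(0.0, 1.0 - x)
        return lr_lambda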

    ###################
    # Train the model #
    ###################

    stats_postfix = None
    if args.validate_only:
        stats_postfix = args.validate
    if args.load_model:
        all_stats = init_stats(args, params, postfix=stats_postfix)
    else:
        all_stats = {}

    if args.force_epoch:
        start_epoch = args.force_epoch - 1

    if not args.validate_only:
        total_step = len(data_loader)
        print(
            'Start training with start_epoch={:d} num_epochs={:d} num_batches={:d} ...'
            .format(start_epoch, args.num_epochs, args.num_batches))

    if args.teacher_forcing != 'always':
        print('\t k: {}'.format(args.teacher_forcing_k))
        print('\t beta: {}'.format(args.teacher_forcing_beta))
    print('Optimizer:', optimizer)
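    # `get_teacher_prob` comes from the project's utilities; scheduled
    # sampling commonly uses an inverse-sigmoid decay (Bengio et al., 2015).
    # A sketch of that decay (an assumption -- the project's function may
    # differ) in terms of the k and beta printed above:
    def _inverse_sigmoid_tf_sketch(k, iteration, beta=1.0):
        # Starts near 1.0 (always teacher-force) and decays towards 1 - beta.
        return beta * (k / (k + np.exp(iteration / k))) + (1.0 - beta)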

    if args.validate_only:
        stats = {}
        teacher_p = 1.0
        if args.teacher_forcing != 'always':
            print(
                'WARNING: teacher_forcing!=always, not yet implemented for --validate_only mode'
            )

        epoch = start_epoch - 1
        if str(epoch + 1) in all_stats and args.skip_existing_validations:
            print('WARNING: epoch {} already validated, skipping...'.format(
                epoch + 1))
            return

        val_loss = do_validate(model, valid_loader, criterion, scorers, vocab,
                               teacher_p, args, params, stats, epoch,
                               sc_activated, gts_sc_valid)
        all_stats[str(epoch + 1)] = stats
        save_stats(args, params, all_stats, postfix=stats_postfix)
    else:
        for epoch in range(start_epoch, args.num_epochs):
            stats = {}
            begin = datetime.now()

            total_loss = 0

            if params.hierarchical_model:
                total_loss_sent = 0
                total_loss_word = 0

            num_batches = 0
            vocab_counts = {
                'cnt': 0,
                'max': 0,
                'min': 9999,
                'sum': 0,
                'unk_cnt': 0,
                'unk_sum': 0
            }

            # Switch to self-critical training once the configured epoch is reached
            if not sc_activated and sc_will_happen and epoch >= args.self_critical_from_epoch:
                if all_stats:
                    best_ep, best_cider = max(
                        [(ep, all_stats[ep]['validation_cider'])
                         for ep in all_stats],
                        key=lambda x: x[1])
                    print('Loading model from epoch', best_ep,
                          'which has the best validation CIDEr:', best_cider)
                    state = torch.load(
                        get_model_path(args, params, int(best_ep)))
                    model = EncoderDecoder(params,
                                           device,
                                           len(vocab),
                                           state,
                                           ef_dims,
                                           lr_dict=lr_dict)
                    opt_params = model.get_opt_params()

                optimizer = torch.optim.Adam(opt_params,
                                             lr=5e-5,
                                             weight_decay=args.weight_decay)
                criterion = rl_criterion
                print('Self-critical loss training begins')
                sc_activated = True

            for i, data in enumerate(data_loader):

                if params.hierarchical_model:
                    (images, captions, lengths, image_ids, features,
                     sorting_order, last_sentence_indicator) = data
                    sorting_order = sorting_order.to(device)
                else:
                    (images, captions, lengths, image_ids, features) = data

                if epoch == 0:
                    unk = vocab('<unk>')
                    for j in range(captions.shape[0]):
                        # Flatten the caption in case it's a paragraph
                        # this is harmless for regular captions too:
                        xl = captions[j, :].view(-1)
                        xw = xl > unk
                        xu = xl == unk
                        xwi = sum(xw).item()
                        xui = sum(xu).item()
                        vocab_counts['cnt'] += 1
                        vocab_counts['sum'] += xwi
                        vocab_counts['max'] = max(vocab_counts['max'], xwi)
                        vocab_counts['min'] = min(vocab_counts['min'], xwi)
                        vocab_counts['unk_cnt'] += xui > 0
                        vocab_counts['unk_sum'] += xui
                # Set mini-batch dataset
                images = images.to(device)
                captions = captions.to(device)

                # Remove <start> token from targets if we are initializing the RNN
                # hidden state from image features:
                if params.rnn_hidden_init == 'from_features' and not params.hierarchical_model:
                    # Subtract one from all lengths to match new target lengths:
                    lengths = [x - 1 if x > 0 else x for x in lengths]
                    targets = pack_padded_sequence(captions[:, 1:],
                                                   lengths,
                                                   batch_first=True)[0]
                else:
                    if params.hierarchical_model:
                        targets = prepare_hierarchical_targets(
                            last_sentence_indicator, args.max_sentences,
                            lengths, captions, device)
                    else:
                        targets = pack_padded_sequence(captions,
                                                       lengths,
                                                       batch_first=True)[0]
                        sorting_order = None

                init_features = features[0].to(device) if len(
                    features) > 0 and features[0] is not None else None
                persist_features = features[1].to(device) if len(
                    features) > 1 and features[1] is not None else None

                # Forward, backward and optimize
                # Calculate the probability whether to use teacher forcing or not:

                # Iterate over batches:
                iteration = (epoch - start_epoch) * len(data_loader) + i

                teacher_p = get_teacher_prob(args.teacher_forcing_k, iteration,
                                             args.teacher_forcing_beta)

                # Allow model to log values at the last batch of the epoch
                writer_data = None
                if writer and (i == len(data_loader) - 1
                               or i == args.num_batches - 1):
                    writer_data = {'writer': writer, 'epoch': epoch + 1}

                sample_len = captions.size(1) if args.self_critical_loss in [
                    'mixed', 'mixed_with_face'
                ] else 20
                if sc_activated:
                    sampled_seq, sampled_log_probs, outputs = model.sample(
                        images,
                        init_features,
                        persist_features,
                        max_seq_length=sample_len,
                        start_token_id=vocab('<start>'),
                        trigram_penalty_alpha=args.trigram_penalty_alpha,
                        stochastic_sampling=True,
                        output_logprobs=True,
                        output_outputs=True)
                    sampled_seq = model.decoder.alt_prob_to_tensor(
                        sampled_seq, device=device)
                else:
                    outputs = model(images,
                                    init_features,
                                    captions,
                                    lengths,
                                    persist_features,
                                    teacher_p,
                                    args.teacher_forcing,
                                    sorting_order,
                                    writer_data=writer_data)

                if args.share_embedding_weights:
                    # Weights of (HxH) projection matrix used for regularizing
                    # models that share embedding weights
                    projection = model.decoder.projection.weight
                    loss = criterion(projection, outputs, targets)
                elif sc_activated:
                    # get greedy decoding baseline
                    model.eval()
                    with torch.no_grad():
                        greedy_sampled_seq = model.sample(
                            images,
                            init_features,
                            persist_features,
                            max_seq_length=sample_len,
                            start_token_id=vocab('<start>'),
                            trigram_penalty_alpha=args.trigram_penalty_alpha,
                            stochastic_sampling=False)
                        greedy_sampled_seq = model.decoder.alt_prob_to_tensor(
                            greedy_sampled_seq, device=device)
                    model.train()

                    if args.self_critical_loss in [
                            'sc', 'sc_with_diversity',
                            'sc_with_relative_diversity',
                            'sc_with_bleu_diversity', 'sc_with_repetition'
                    ]:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            greedy_sampled_seq, [gts_sc[i] for i in image_ids],
                            scorers,
                            vocab,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed']:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            outputs,
                            greedy_sampled_seq, [gts_sc[i] for i in image_ids],
                            scorers,
                            vocab,
                            targets,
                            lengths,
                            gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed_with_face']:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            outputs,
                            greedy_sampled_seq, [gts_sc[i] for i in image_ids],
                            scorers,
                            vocab,
                            captions,
                            targets,
                            lengths,
                            gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    else:
                        raise ValueError('Invalid self-critical loss')

                    if writer is not None and i % 100 == 0:
                        writer.add_scalar('training_loss', loss.item(),
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('advantage', advantage,
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('lr',
                                          optimizer.param_groups[0]['lr'],
                                          epoch * len(data_loader) + i)
                else:
                    loss = criterion(outputs, targets)

                model.zero_grad()
                loss.backward()

                # Clip gradients if desired:
                if args.grad_clip is not None:
                    # grad_norms = [x.grad.data.norm(2) for x in opt_params]
                    # batch_max_grad = np.max(grad_norms)
                    # if batch_max_grad > 10.0:
                    #     print('WARNING: gradient norms larger than 10.0')

                    # torch.nn.utils.clip_grad_norm_(decoder.parameters(), 0.1)
                    # torch.nn.utils.clip_grad_norm_(encoder.parameters(), 0.1)
                    clip_gradients(optimizer, args.grad_clip)
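                    # (`clip_gradients` is a project helper; plausibly it
                    # applies torch.nn.utils.clip_grad_norm_ over the
                    # optimizer's parameter groups -- an assumption, since
                    # its definition is not shown here.)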

                # Update weights:
                optimizer.step()

                # CyclicalLR requires us to update LR at every minibatch:
                if args.lr_scheduler == 'CyclicalLR':
                    scheduler.step()

                total_loss += loss.item()

                num_batches += 1

                if params.hierarchical_model:
                    _, loss_sent, _, loss_word = criterion.item_terms()
                    total_loss_sent += float(loss_sent)
                    total_loss_word += float(loss_word)

                # Print log info
                if (i + 1) % args.log_step == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, '
                          'Perplexity: {:5.4f}'.format(epoch + 1,
                                                       args.num_epochs, i + 1,
                                                       total_step, loss.item(),
                                                       np.exp(loss.item())))
                    sys.stdout.flush()

                    if params.hierarchical_model:
                        weight_sent, loss_sent, weight_word, loss_word = \
                            criterion.item_terms()
                        print('Sentence Loss: {:.4f}, '
                              'Word Loss: {:.4f}'.format(
                                  float(loss_sent), float(loss_word)))
                        sys.stdout.flush()

                if i + 1 == args.num_batches:
                    break

            end = datetime.now()

            stats['training_loss'] = total_loss / num_batches

            if params.hierarchical_model:
                stats['loss_sentence'] = total_loss_sent / num_batches
                stats['loss_word'] = total_loss_word / num_batches

            print('Epoch {} duration: {}, average loss: {:.4f}'.format(
                epoch + 1, end - begin, stats['training_loss']))

            save_model(args, params, model.encoder, model.decoder, optimizer,
                       epoch, vocab)

            if epoch == 0:
                vocab_counts['avg'] = vocab_counts['sum'] / vocab_counts['cnt']
                vocab_counts['unk_cnt_per'] = 100 * vocab_counts[
                    'unk_cnt'] / vocab_counts['cnt']
                vocab_counts['unk_sum_per'] = 100 * vocab_counts[
                    'unk_sum'] / vocab_counts['sum']
                # print(vocab_counts)
                print((
                    'Training data contains {sum} words in {cnt} captions (avg. {avg:.1f} w/c)'
                    + ' with {unk_sum} <unk>s ({unk_sum_per:.1f}%)' +
                    ' in {unk_cnt} ({unk_cnt_per:.1f}%) captions').format(
                        **vocab_counts))

            ############################################
            # Validation loss and learning rate update #
            ############################################

            if args.validate is not None and (epoch +
                                              1) % args.validation_step == 0:
                val_loss = do_validate(model, valid_loader, criterion, scorers,
                                       vocab, teacher_p, args, params, stats,
                                       epoch, sc_activated, gts_sc_valid)

                if args.lr_scheduler == 'ReduceLROnPlateau':
                    scheduler.step(val_loss)
            elif args.lr_scheduler == 'StepLR':
                scheduler.step()

            all_stats[str(epoch + 1)] = stats
            save_stats(args, params, all_stats, writer=writer)

            if writer is not None:
                # Log model data to tensorboard
                log_model_data(params, model, epoch + 1, writer)

    if writer is not None:
        writer.close()
Example #16
0
class infer_object:
    def __init__(self, args):
        # print('__init__() :', args)

        if 'cpu' not in args:
            args['cpu'] = False
        if 'vocab' not in args:
            args['vocab'] = None
        if 'ext_features' not in args:
            args['ext_features'] = None
        if 'ext_persist_features' not in args:
            args['ext_persist_features'] = None
        if 'lemma_pos_rules' not in args:
            args['lemma_pos_rules'] = None
        
        global device
        device = torch.device('cuda' if torch.cuda.is_available() and
                              not args['cpu'] else 'cpu')

        vi = sys.version_info
        print('Python version {}.{}.{}, torch version {}'.format(vi[0], vi[1], vi[2],
                                                                 torch.__version__))
        
        # Build models
        print('Building model for device {}'.format(device.type))

        try:
            self.state = torch.load(args['model'], map_location=device)
        except AttributeError:
            print('WARNING: Old model found. ' +
                  'Please run model_update.py on the model before executing this script.')
            exit(1)
            
        self.params = ModelParams(self.state)
        if args['ext_features']:
            self.params.update_ext_features(args['ext_features'])
        if args['ext_persist_features']:
            self.params.update_ext_persist_features(args['ext_persist_features'])

        print('Loaded model parameters from <{}>:'.format(args['model']))
        print(self.params)

        # Load the vocabulary:
        if args['vocab'] is not None:
            # Loading vocabulary from file path supplied by the user:
            args_attr = AttributeDict(args)
            self.vocab = get_vocab(args_attr)
        elif self.params.vocab is not None:
            print('Loading vocabulary stored in the model file.')
            self.vocab = self.params.vocab
        else:
            print('ERROR: you must either load a model that contains vocabulary or '
                  'specify a vocabulary with the --vocab option!')
            sys.exit(1)

            print('Size of the vocabulary is {}'.format(len(self.vocab)))

        ef_dims = None
        self.model = EncoderDecoder(self.params, device, len(self.vocab), self.state, ef_dims).eval()
        # print(self.params.ext_features_dim)

        self.lemma_pos_rules = {}
        self.pos_names = set()
        if args['lemma_pos_rules'] is not None:
            self.read_lemma_pos_rules(args['lemma_pos_rules'])
            
    def external_features(self):
        ef_dims = self.params.ext_features_dim
        ret = [(self.state['features'].external,         ef_dims[0]),
               (self.state['persist_features'].external, ef_dims[1])]
        # print(ret)
        return ret


    def read_lemma_pos_rules(self, file):
        # print('Reading lemma_pos_rules from <{}>'.format(file))
        with open(file) as fp:
            n = 0
            for line in fp:
                # Strip the trailing newline so the surface form does not
                # keep a '\n' suffix:
                ll = line.rstrip('\n').split(' ')
                l, p, w = ll[:3]
                # print(l, p, w)
                if p not in self.lemma_pos_rules:
                    self.lemma_pos_rules[p] = {}
                assert l not in self.lemma_pos_rules[p], \
                    'duplicate rule for lemma <{}> with pos <{}>'.format(l, p)
                self.lemma_pos_rules[p][l] = w
                n += 1
            self.pos_names = self.lemma_pos_rules.keys()
            print('Read {} lemma_pos_rules for {} pos from <{}>'.format(
                n, len(self.pos_names), file))
        assert len(self.pos_names) > 0, \
            'failed reading lemma_pos_rules from <{}>'.format(file)
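    # Rule-file format consumed above: one rule per line, space separated, as
    #   <lemma> <POS> <surface-form>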
            

    def apply_lemma_pos_rules(self, caption):
        # Rewrite "<lemma> <POS>" pairs in the caption into the surface form
        # given by the rules; the POS marker tokens themselves are dropped.
        if len(self.pos_names) == 0:
            return caption

        w = caption.split(' ')
        x = []
        for i in range(len(w)):
            if w[i] in self.pos_names:
                continue
            if i + 1 < len(w) and w[i + 1] in self.pos_names:
                # Next token is a POS marker: replace this lemma if a rule exists
                if w[i] in self.lemma_pos_rules[w[i + 1]]:
                    x.append(self.lemma_pos_rules[w[i + 1]][w[i]])
                else:
                    x.append(w[i])
            else:
                x.append(w[i])

        return ' '.join(x)
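    # Illustrative example: with a rule line "cat NOUN cats" loaded,
    # apply_lemma_pos_rules('the cat NOUN sleeps') replaces the lemma and
    # drops the POS marker, returning 'the cats sleeps'.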


    def infer(self, args):
        # print('infer() :', args)

        if 'image_features' not in args:
            args['image_features'] = None

        # Image preprocessing
        transform = transforms.Compose([
            transforms.Resize((args['resize'], args['resize'])),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])

        # Get dataset parameters:
        dataset_configs = DatasetParams(args['dataset_config_file'])
        dataset_params = dataset_configs.get_params(args['dataset'],
                                                    args['image_dir'],
                                                    args['image_files'],
                                                    args['image_features'])

        if self.params.has_external_features() and \
           any(dc.name == 'generic' for dc in dataset_params):
            print('WARNING: you cannot use external features without specifying all datasets in '
                  'datasets.conf.')
            print('Hint: take a look at datasets/datasets.conf.default.')

        # Build data loader
        print("Loading dataset: {}".format(args['dataset']))

        # Update dataset params with needed model params:
        for i in dataset_params:
            i.config_dict['skip_start_token'] = self.params.skip_start_token
            # For visualizing attention we need file names instead of IDs in our output:
            if args['store_image_paths']:
                i.config_dict['return_image_file_name'] = True

        ext_feature_sets = [self.params.features.external, self.params.persist_features.external]
        if args['dataset'] == 'incore':
            ext_feature_sets = None
        
        # We ask it to iterate over images instead of all (image, caption) pairs
        data_loader, ef_dims = get_loader(dataset_params, vocab=None, transform=transform,
                                          batch_size=args['batch_size'], shuffle=False,
                                          num_workers=args['num_workers'],
                                          ext_feature_sets=ext_feature_sets,
                                          skip_images=not self.params.has_internal_features(),
                                          iter_over_images=True)

        self.data_loader = data_loader
        # Create model directory
        if not os.path.exists(args['results_path']):
            os.makedirs(args['results_path'])

        scorers = {}
        if args['scoring'] is not None:
            for s in args['scoring'].split(','):
                s = s.lower().strip()
                if s == 'cider':
                    from eval.cider import Cider
                    scorers['CIDEr'] = Cider(df='corpus')

        # Store captions here:
        output_data = []

        gts = {}
        res = {}

        print('Starting inference, max sentence length: {} num_workers: {}'.\
              format(args['max_seq_length'], args['num_workers']))
        show_progress = sys.stderr.isatty() and not args['verbose'] \
                        and ext_feature_sets is not None

        for i, (images, ref_captions, lengths, image_ids,
                features) in enumerate(tqdm(self.data_loader, disable=not show_progress)):

            if len(scorers) > 0:
                for j in range(len(ref_captions)):
                    jid = image_ids[j]
                    if jid not in gts:
                        gts[jid] = []
                    rcs = ref_captions[j]
                    if type(rcs) is str:
                        rcs = [rcs]
                    for rc in rcs:
                        gts[jid].append(rc.lower())

            images = images.to(device)

            init_features    = features[0].to(device) if len(features) > 0 and \
                               features[0] is not None else None
            persist_features = features[1].to(device) if len(features) > 1 and \
                               features[1] is not None else None

            # Generate a caption from the image
            sampled_batch = self.model.sample(images, init_features, persist_features,
                                              max_seq_length=args['max_seq_length'],
                                              start_token_id=self.vocab('<start>'),
                                              end_token_id=self.vocab('<end>'),
                                              alternatives=args['alternatives'],
                                              probabilities=args['probabilities'])

            sampled_ids_batch = sampled_batch

            # Use a separate loop variable so the enumerate() index above is
            # not shadowed:
            for j in range(len(sampled_ids_batch)):
                sampled_ids = sampled_ids_batch[j]

                # Convert word_ids to words
                if self.params.hierarchical_model:
                    # assert False, 'paragraph_ids_to_words() need to be updated'
                    caption = paragraph_ids_to_words(sampled_ids, self.vocab,
                                                     skip_start_token=True)
                else:
                    caption = caption_ids_ext_to_words(sampled_ids, self.vocab,
                                                       skip_start_token=True,
                                                       capitalize=not args['no_capitalize'])
                if args['no_repeat_sentences']:
                    caption = remove_duplicate_sentences(caption)

                if args['only_complete_sentences']:
                    caption = remove_incomplete_sentences(caption)

                if args['verbose']:
                    print('=>', caption)

                caption = self.apply_lemma_pos_rules(caption)
                if args['verbose']:
                    print('#>', caption)
                    
                output_data.append({'image_id': image_ids[j],
                                    'caption': caption})
                res[image_ids[j]] = [caption.lower()]

        for score_name, scorer in scorers.items():
            score = scorer.compute_score(gts, res)[0]
            print('Test', score_name, score)

        # Decide output format, fall back to txt
        if args['output_format'] is not None:
            output_format = args['output_format']
        elif args['output_file'] and args['output_file'].endswith('.json'):
            output_format = 'json'
        else:
            output_format = 'txt'

        # Create a sensible default output path for results:
        output_file = None
        if not args['output_file'] and not args['print_results']:
            model_name_path = Path(args['model'])
            is_in_same_folder = len(model_name_path.parents) == 1
            if not is_in_same_folder:
                model_name = args['model'].split(os.sep)[-2]
                model_epoch = basename(args['model'])
                output_file = '{}-{}.{}'.format(model_name, model_epoch,
                                                output_format)
            else:
                output_file = model_name_path.stem + '.' + output_format
        else:
            output_file = args['output_file']

        if output_file:
            output_path = os.path.join(args['results_path'], output_file)
            if output_format == 'json':
                json.dump(output_data, open(output_path, 'w'))
            else:
                with open(output_path, 'w') as fp:
                    for data in output_data:
                        print(data['image_id'], data['caption'], file=fp)

            print('Wrote generated captions to {} as {}'.
                  format(output_path, output_format))

        if args['print_results']:
            for d in output_data:
                print('{}: {}'.format(d['image_id'], d['caption']))

        return output_data
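    # Illustrative usage of this class (the key names are the ones read
    # above; the concrete values are hypothetical):
    #
    #   io = infer_object({'model': 'model/my_model.ckpt'})
    #   captions = io.infer({'dataset': 'incore',
    #                        'dataset_config_file': 'datasets/datasets.conf',
    #                        ...})
    #
    # infer() returns a list of {'image_id': ..., 'caption': ...} dicts.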
Example #17
0
def main(model_name, use_cuda, batch_size, teacher_forcing_schedule, keep_prob, val_size, lr, decoder_type, vocab_limit, hidden_size, embedding_size, max_length, main_data, test_data, device, seed=42):
    print("Max Length is: ", max_length)
    model_path = './model/' + model_name + '/'

    print("training %s with use_cuda=%s, batch_size=%i"% (model_name, use_cuda, batch_size), flush=True)
    print("teacher_forcing_schedule=", teacher_forcing_schedule, flush=True)
    print("keep_prob=%f, val_size=%f, lr=%f, decoder_type=%s, vocab_limit=%i, hidden_size=%i, embedding_size=%i, max_length=%i, seed=%i" % (keep_prob, val_size, lr, decoder_type, vocab_limit, hidden_size, embedding_size, max_length, seed), flush=True)

    train_src, train_tgt, val_src, val_tgt = load_split_eighty_twenty(main_data, seed)

    if test_data:
        test_src, test_tgt = load_complete_data(test_data)
    else:
        # Keep the names defined so the dataset creation below does not
        # raise a NameError when no separate test set is given:
        test_src, test_tgt = [], []
    
    if os.path.isdir(model_path):

        print("loading encoder and decoder from model_path", flush=True)
        encoder_decoder = torch.load(model_path + model_name + '_final.pt')

        print("creating training, validation, and testing datasets with saved languages", flush=True)
        train_dataset = SequencePairDataset(train_src, train_tgt,
                                            lang=encoder_decoder.lang,
                                            use_extended_vocab=(encoder_decoder.decoder_type=='copy'))

        val_dataset = SequencePairDataset(val_src, val_tgt,
                                          lang=encoder_decoder.lang,
                                          use_extended_vocab=(encoder_decoder.decoder_type=='copy'))

        test_dataset = SequencePairDataset(test_src, test_tgt,
                                           lang=encoder_decoder.lang,
                                           use_extended_vocab=(encoder_decoder.decoder_type=='copy'))

    else:
        os.mkdir(model_path)

        print("creating training, validation, and testing datasets", flush=True)
        train_dataset = SequencePairDataset(train_src, train_tgt,
                                            vocab_limit=vocab_limit,
                                            use_extended_vocab=(decoder_type=='copy'))

        val_dataset = SequencePairDataset(val_src, val_tgt,
                                          lang=train_dataset.lang,
                                          use_extended_vocab=(decoder_type=='copy'))

        test_dataset = SequencePairDataset(test_src, test_tgt,
                                           lang=train_dataset.lang,
                                           use_extended_vocab=(decoder_type=='copy'))

        print("creating encoder-decoder model", flush=True)
        encoder_decoder = EncoderDecoder(train_dataset.lang,
                                         max_length,
                                         embedding_size,
                                         hidden_size,
                                         decoder_type,
                                         device)

        torch.save(encoder_decoder, model_path + '%s.pt' % model_name)

    encoder_decoder = encoder_decoder.to(device)

    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)
    val_data_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_data_loader = DataLoader(test_dataset, batch_size=batch_size)

    # `mixed_dataset` is never constructed in this example and the resulting
    # loader is unused below, so it is left commented out:
    # mixed_data_loader = DataLoader(mixed_dataset, batch_size=batch_size)

    trained_model = train(encoder_decoder,
          train_data_loader,
          model_name,
          val_data_loader,
          keep_prob,
          teacher_forcing_schedule,
          lr,
          encoder_decoder.decoder.max_length,
          device)

    trained_model = torch.load(model_path + model_name + '_final.pt')

    # Write final model errors to an output log
    if test_data:
        f = open("./logs/log_" + model_name + ".txt", "w")
        f.write("MODEL {}\n\n".format(model_name))
        f.write("UNSEEN ACCURACY\n")
        with torch.no_grad():
            accuracy = test(trained_model, test_data_loader, encoder_decoder.decoder.max_length, device, log_file=f)
        f.close()