class SketchRNNDecoderGMMOnlyModel(SketchRNNModel):
    def __init__(self, hp, save_dir, skip_data=False):
        super().__init__(hp, save_dir, skip_data=skip_data)

        # Model
        self.category_embedding = None
        if hp.use_categories_dec:
            self.category_embedding = nn.Embedding(35, self.hp.categories_dim)
            self.models.append(self.category_embedding)
        inp_dim = (5 + hp.categories_dim) if self.category_embedding else 5
        self.dec = SketchRNNDecoderGMM(inp_dim, hp.dec_dim, hp.M, dropout=self.hp.dropout,
            use_layer_norm=self.hp.use_layer_norm, rec_dropout=self.hp.rec_dropout)
        self.models.append(self.dec)
        if USE_CUDA:
            for model in self.models:
                model.cuda()

        # optimization -- ADAM plus annealing (supp eq. 4)
        self.optimizers.append(optim.Adam(self.parameters(), hp.lr))

    def one_forward_pass(self, batch):
        """
        Return loss and other items of interest for one forward pass

        Args:
            batch: tuple from DataLoaders

        Returns:
            dict where 'loss': float Tensor must exist
        """
        strokes, stroke_lens, cats, cats_idx = batch
        max_len, bsz, _ = strokes.size()

        # Create inputs to decoder
        sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(0)  # start of sequence
        sos = nn_utils.move_to_cuda(sos)
        dec_inputs = torch.cat([sos, strokes], 0)  # prepend sos to the strokes; [max_len + 1, bsz, 5]

        if self.hp.use_categories_dec:
            cat_embs = self.category_embedding(cats_idx)  # [bsz, cat_dim]
            cat_embs = cat_embs.repeat(dec_inputs.size(0), 1, 1)  # [max_len + 1, bsz, cat_dim]
            dec_inputs = torch.cat([dec_inputs, cat_embs], dim=2)  # [max_len+1, bsz, 5 + cat_dim]

        # Decode
        outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(dec_inputs, output_all=True)

        # Calculate losses
        mask, dx, dy, p = self.dec.make_target(strokes, stroke_lens, self.hp.M)

        loss = self.dec.reconstruction_loss(mask,
                                            dx, dy, p,
                                            pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy,
                                            q)
        result = {'loss': loss, 'loss_R': loss}

        if torch.isnan(loss).any() or torch.isinf(loss).any():
            raise Exception('NaN/Inf in SketchRNNDecoderGMMOnly forward pass')

        return result
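
# A minimal, self-contained sketch of the decoder-input construction used in
# one_forward_pass above. The toy shapes here are illustrative assumptions, not
# values from the actual data pipeline.
def _demo_stroke5_decoder_inputs():
    import torch
    max_len, bsz = 7, 2
    strokes = torch.randn(max_len, bsz, 5)  # stroke-5 batch: [dx, dy, p1, p2, p3]
    # start-of-sequence stroke: zero offsets, pen-down (p1 = 1)
    sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(0)  # [1, bsz, 5]
    dec_inputs = torch.cat([sos, strokes], dim=0)  # [max_len + 1, bsz, 5]
    assert dec_inputs.shape == (max_len + 1, bsz, 5)
    return dec_inputs
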
class InstructionToStrokesModel(SketchRNNModel):
    """"
    SketchRNN that also encodes and conditions on top-level instruction (i.e. instruction for entire
    drawing) generated by an instruction generation model.
    """
    def __init__(self, hp, save_dir):
        super().__init__(hp, save_dir)

        self.end_epoch_loader = None  # TODO: not generating yet, need to refactor that

        # Model
        # Load text embeddings
        # TODO: move this into some config file
        vocab_size = len(
            utils.load_file(LABELED_PROGRESSION_PAIRS_IDX2TOKEN_PATH))
        strokes_to_instruction_fp = 'best_models/strokes_to_instruction/catsdecoder-dim_512-model_type_cnn_lstm-use_prestrokes_False/model.pt'
        weights = torch.load(strokes_to_instruction_fp)
        enc_dim = weights['token_embedding.weight'].size(1)
        # enc_dim = hp.enc_dim
        # self.text_embedding = nn.Embedding(vocab_size, hp.enc_dim)
        self.text_embedding = nn.Embedding(
            vocab_size, enc_dim)  # if we're loading, must be the same size
        self.text_embedding.weight = nn.Parameter(
            weights['token_embedding.weight'])

        self.enc = InstructionEncoderTransformer(
            enc_dim, hp.enc_num_layers, hp.dropout,
            use_categories=False)  # TODO: should this be a hparam
        dec_input_dim = 5 if (hp.cond_instructions == 'initdec') else (
            5 + enc_dim)  # dec_inputs
        self.dec = SketchRNNDecoderGMM(
            dec_input_dim, hp.dec_dim,
            hp.M)  # Method 1 (see one_forward_pass, i.e. decinputs)

        self.models.extend([self.text_embedding, self.enc, self.dec])
        if USE_CUDA:
            for model in self.models:
                model.cuda()

        self.optimizers.append(optim.Adam(self.parameters(), hp.lr))

    def get_data_loader(self, dataset_split, batch_size, categories, max_len,
                        max_per_category, shuffle):
        """
        Args:
            dataset_split (str): 'train', 'valid', 'test'
            batch_size (int)
            categories (str
            shuffle (bool)
        """
        ds = ProgressionPairDataset(dataset_split)
        loader = DataLoader(ds,
                            batch_size=batch_size,
                            shuffle=shuffle,
                            collate_fn=ProgressionPairDataset.collate_fn)
        return loader

    def preprocess_batch_from_data_loader(self, batch):
        """
        Convert tensors to cuda and convert to [len, bsz, ...] instead of [bsz, len, ...]
        """
        preprocessed = []
        for item in batch:
            if isinstance(item, torch.Tensor):
                item = nn_utils.move_to_cuda(item)
                if item.dim() > 1:
                    item.transpose_(0, 1)
            preprocessed.append(item)
        return preprocessed
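
    # A toy illustration of the preprocessing above (shapes are assumed values):
    # a [bsz=8, len=20, 5] strokes tensor from the DataLoader comes out as
    # [20, 8, 5] after the in-place transpose; non-tensor items (e.g. the list
    # of category strings) pass through unchanged.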

    def one_forward_pass(self, batch, average_loss=True):
        """
        Return loss and other items of interest for one forward pass

        Args:
            batch: tuple from DataLoaders
            average_loss (bool): whether to average loss per batch item
                - Current use case: Segmentation model computes loss per segment. Batches
                are a batch of segments for one example.

        Returns:
            dict where 'loss': float Tensor must exist
        """
        strokes, stroke_lens, texts, text_lens, text_indices, cats, cats_idx, urls = batch

        # Create base inputs to decoder
        _, bsz, _ = strokes.size()
        sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(
            0)  # start of sequence
        sos = nn_utils.move_to_cuda(sos)
        dec_inputs = torch.cat(
            [sos, strokes], dim=0
        )  # prepend sos to the strokes; [max_len + 1, bsz, 5]

        #
        # Encode instructions, decode
        # text_indices: [len, bsz], text_lens: [bsz]
        hidden = self.enc(text_indices,
                          text_lens,
                          self.text_embedding,
                          category_embedding=None,
                          categories=cats_idx)  # [bsz, dim]

        # Method 1: concatenate instruction embedding to every time step
        if self.hp.cond_instructions == 'decinputs':
            hidden = hidden.unsqueeze(0)  #  [1, bsz, dim]
            hidden = hidden.repeat(dec_inputs.size(0), 1,
                                   1)  # [max_len + 1, bsz, dim]
            dec_inputs = torch.cat([dec_inputs, hidden],
                                   dim=2)  # [max_len + 1, bsz, 5 + dim]
            outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
                dec_inputs, output_all=True)

        # Method 2: initialize decoder's hidden state with instruction embedding
        elif self.hp.cond_instructions == 'initdec':
            hidden = hidden.unsqueeze(0)  #  [1, bsz, dim]
            hidden_cell = (hidden, hidden.clone())
            outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
                dec_inputs, output_all=True, hidden_cell=hidden_cell)

        #
        # Calculate losses
        #
        mask, dx, dy, p = self.dec.make_target(strokes, stroke_lens, self.hp.M)
        loss = self.dec.reconstruction_loss(mask,
                                            dx,
                                            dy,
                                            p,
                                            pi,
                                            mu_x,
                                            mu_y,
                                            sigma_x,
                                            sigma_y,
                                            rho_xy,
                                            q,
                                            average_loss=average_loss)
        result = {'loss': loss, 'loss_R': loss}

        if torch.isnan(loss).any() or torch.isinf(loss).any():
            raise Exception('NaN/Inf in InstructionToStrokesModel forward pass')

        return result
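
    # The two conditioning methods above, side by side as a shape sketch
    # (dims assumed): given an instruction embedding h of shape [bsz, dim],
    #   'decinputs': h.unsqueeze(0).repeat(L, 1, 1) is concatenated onto every
    #                decoder input step, giving dec_inputs of [L, bsz, 5 + dim];
    #   'initdec':   (h.unsqueeze(0), h.clone()) seeds the decoder LSTM's
    #                (hidden, cell) and dec_inputs stay [L, bsz, 5].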

    def generate_and_save(self,
                          data_loader,
                          epoch,
                          n_gens=1,
                          outputs_path=None):
        # TODO: need to overwrite this. SketchRNN's generate_and_save() unpacks a batch from
        # a dataset that returns 4 values. The SketchWithPlansDataset returns 8 values. The
        # encoding of the instructions is different too. Need to refactor that bit to work
        # with both.
        pass

    def inference_pass(self, batch_of_segs, seg_lens, cats_idx):
        pass
class SketchRNNWithPlans(SketchRNNModel):
    """"
    SketchRNN that also encodes and conditions on top-level instruction (i.e. instruction for entire
    drawing) generated by an instruction generation model.
    """
    def __init__(self, hp, save_dir):
        super().__init__(hp, save_dir, skip_data=True)

        self.tr_loader = self.get_data_loader('train', True,
                                              self.hp.batch_size)
        self.val_loader = self.get_data_loader('valid', False,
                                               self.hp.batch_size)
        self.end_epoch_loader = self.get_data_loader('valid',
                                                     False,
                                                     batch_size=1)

        #
        # Model
        #

        # text and category embeddings
        self.text_embedding = nn.Embedding(self.tr_loader.dataset.vocab_size,
                                           hp.enc_dim)
        self.models.append(self.text_embedding)
        self.category_embedding = None
        if hp.use_categories_dec:
            self.category_embedding = nn.Embedding(35, hp.categories_dim)
            self.models.append(self.category_embedding)

        # decoder
        dec_input_dim = 5
        if self.category_embedding:
            dec_input_dim += hp.categories_dim
        if hp.cond_instructions == 'decinputs':
            dec_input_dim += hp.enc_dim
        self.dec = SketchRNNDecoderGMM(dec_input_dim,
                                       hp.dec_dim,
                                       hp.M,
                                       dropout=hp.dropout,
                                       use_layer_norm=hp.use_layer_norm,
                                       rec_dropout=hp.rec_dropout)
        self.models.append(self.dec)

        # For shaping representation loss, want to decode into instructions
        if hp.cond_instructions == 'decinputs':
            self.enc = InstructionEncoderTransformer(
                hp.enc_dim,
                hp.enc_num_layers,
                hp.dropout,
                use_categories=True,
                categories_dim=hp.categories_dim)
            self.models.append(self.enc)
        elif hp.cond_instructions == 'match':
            if hp.loss_match == 'triplet':
                # TODO: use category embedding? (see decinputs)
                self.enc = InstructionEncoderTransformer(
                    hp.enc_dim,
                    hp.enc_num_layers,
                    hp.dropout,
                    use_categories=False)  # TODO: should this be a hparam
                self.fc_dec = nn.Linear(
                    hp.dec_dim, hp.enc_dim
                )  # project decoder hidden states to enc hidden states
                self.models.extend([self.enc, self.fc_dec])
            elif hp.loss_match == 'decode':
                ins_hid_dim = hp.enc_dim  # must equal enc_dim: there is no fc_out layer, so outputs are computed by multiplying hidden states with the token embedding directly
                self.ins_dec = InstructionDecoderLSTM(hp.enc_dim +
                                                      hp.categories_dim,
                                                      ins_hid_dim,
                                                      num_layers=4,
                                                      dropout=0.1,
                                                      batch_first=False,
                                                      condition_on_hc=False,
                                                      use_categories=True)
                self.models.append(self.ins_dec)

        if USE_CUDA:
            for model in self.models:
                model.cuda()

        self.optimizers.append(optim.Adam(self.parameters(), hp.lr))

    def get_data_loader(self, dataset_split, shuffle, batch_size):
        """
        Args:
            dataset_split (str): 'train', 'valid', 'test'
            shuffle (bool)
        """
        if self.hp.instruction_set in [
                'toplevel', 'toplevel_leaves', 'leaves'
        ]:
            ds = SketchWithPlansConditionEntireDrawingDataset(
                dataset=self.hp.dataset,
                max_len=self.hp.max_len,
                categories=self.hp.categories,
                max_per_category=self.hp.max_per_category,
                dataset_split=dataset_split,
                instruction_set=self.hp.instruction_set,
                prob_threshold=self.hp.prob_threshold)
            loader = DataLoader(ds,
                                batch_size=batch_size,
                                shuffle=shuffle,
                                collate_fn=ProgressionPairDataset.collate_fn)
        elif self.hp.instruction_set in ['stack', 'stack_leaves']:
            ds = SketchWithPlansConditionSegmentsDataset(
                dataset=self.hp.dataset,
                max_len=self.hp.max_len,
                categories=self.hp.categories,
                max_per_category=self.hp.max_per_category,
                dataset_split=dataset_split,
                instruction_set=self.hp.instruction_set,
                prob_threshold=self.hp.prob_threshold)
            loader = DataLoader(
                ds,
                batch_size=batch_size,
                shuffle=shuffle,
                collate_fn=partial(
                    SketchWithPlansConditionSegmentsDataset.collate_fn,
                    token2idx=ds.token2idx))
        return loader

    def preprocess_batch_from_data_loader(self, batch):
        """
        Convert tensors to cuda and convert to [len, bsz, ...] instead of [bsz, len, ...]
        """
        preprocessed = []
        for item in batch:
            if isinstance(item, torch.Tensor):
                item = nn_utils.move_to_cuda(item)
                if item.dim() > 1:
                    item.transpose_(0, 1)
            preprocessed.append(item)
        return preprocessed

    def one_forward_pass(self, batch):
        """
        Return loss and other items of interest for one forward pass

        Args:
            batch: tuple from DataLoaders

        Returns:
            dict where 'loss': float Tensor must exist
        """
        strokes, stroke_lens, texts, text_lens, text_indices, cats, cats_idx, urls = batch
        # batch is 1st dimension (not 0th) due to preprocess_batch()

        # Create base inputs to decoder
        _, bsz, _ = strokes.size()
        sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(
            0)  # start of sequence
        sos = nn_utils.move_to_cuda(sos)
        dec_inputs = torch.cat(
            [sos, strokes], dim=0
        )  # prepend sos to the strokes; [max_len + 1, bsz, 5]

        if self.hp.use_categories_dec:
            cat_embs = self.category_embedding(cats_idx)  # [bsz, cat_dim]
            cat_embs = cat_embs.repeat(dec_inputs.size(0), 1,
                                       1)  # [max_len + 1, bsz, cat_dim]
            dec_inputs = torch.cat([dec_inputs, cat_embs],
                                   dim=2)  # [max_len+1, bsz, 5 + cat_dim]

        #
        # Encode instructions, decode
        #
        if self.hp.instruction_set in ['stack', 'stack_leaves']:
            # text_indices: [max_seq_len, bsz, max_instruction_len], text_lens: [max_seq_len, bsz]

            # decoder's hidden states are "matched" with language representations
            outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
                dec_inputs, output_all=True)

            # For each sequence in batch, divide drawing into segments (based on penup strokes)
            # For each segment, compute a matching loss between its hidden state and the
            # the encoded instruction stack for that segment
            all_instruction_embs = []
            all_seg_hiddens = []
            for i in range(bsz):
                penups = np.where(
                    strokes[:, i, :].cpu().numpy()[:, 3] == 1)[0].tolist()
                penups = penups if (penups[0] == 0) else ([0] + penups)  # first element could already be 0
                # TODO: find other place that I do [0] + penups, see if I need to account for the case
                # where the first element is 0

                # Encode instruction stacks
                # text_indices: [max_seq_len, bsz, max_instruction_len]
                instructions = [
                    text_indices[start_idx, i, :] for start_idx in penups[:-1]
                ]
                # Note on above:
                #   [:-1] because that's the end of the last segment
                #   instructions for each timestep within segment are the same, take the start_idx
                instructions = torch.stack(
                    instructions, dim=1
                )  # [max_instruction_len, n_segs] (max across all segs in batch)
                # (n_segs is the "batch" for the encoder)
                instructions_lens = [
                    text_lens[start_idx, i].item() for start_idx in penups[:-1]
                ]
                instructions = instructions[:max(
                    instructions_lens), :]  # encoder requires this
                cur_cats_idx = [
                    cats_idx[i] for _ in range(len(instructions_lens))
                ]  # all segs are from same drawing (i.e. same category)
                instruction_embs = self.enc(
                    instructions,
                    instructions_lens,
                    self.text_embedding,
                    category_embedding=None,
                    categories=cur_cats_idx)  # [n_segs, dim]
                all_instruction_embs.append(instruction_embs)

                # Compute hidden states mean for each seg
                seg_hiddens = []
                for j in range(len(penups) - 1):  # n_segs
                    start_idx = penups[j]
                    end_idx = penups[j + 1]
                    seg_outputs = outputs[start_idx:end_idx + 1,
                                          i, :]  # [seg_len, dim]
                    seg_hidden = seg_outputs.mean(dim=0)  # [dim]
                    seg_hiddens.append(seg_hidden)
                seg_hiddens = torch.stack(seg_hiddens, dim=0)  # [n_segs, dim]
                all_seg_hiddens.append(seg_hiddens)

            # Concatenate all segs across all batch items
            if self.hp.loss_match == 'triplet':
                all_instruction_embs = torch.cat(
                    all_instruction_embs, dim=0)  # [n_total_segs, enc_dim]
                all_seg_hiddens = torch.cat(all_seg_hiddens,
                                            dim=0)  # [n_total_segs, dec_dim]
                all_seg_hiddens = self.fc_dec(
                    all_seg_hiddens)  # [n_total_segs, enc_dim]

                # Compute triplet loss
                pos = (all_seg_hiddens -
                       all_instruction_embs)**2  # [n_total_segs, enc_dim]
                all_instruction_embs_shuffled = all_instruction_embs[
                    torch.randperm(pos.size(0)), :]  # [n_total_segs, enc_dim]
                neg = (all_seg_hiddens - all_instruction_embs_shuffled
                       )**2  # [n_total_segs, enc_dim]
                loss_match = (pos - neg).mean() + 0.1  # positive - negative + margin alpha
                loss_match = torch.clamp(loss_match, min=0.0)  # hinge; stays on pos's device
            elif self.hp.loss_match == 'decode':
                raise NotImplementedError

            # TODO: check if text_indices is correct

        elif self.hp.instruction_set in [
                'toplevel', 'toplevel_leaves', 'leaves'
        ]:
            # triplet loss
            if self.hp.cond_instructions == 'decinputs':  # concatenate instruction embedding to every time step
                # Encode instructions
                # text_indices: [len, bsz], text_lens: [bsz]
                instructions_emb = self.enc(
                    text_indices,
                    text_lens,
                    self.text_embedding,
                    category_embedding=self.category_embedding,
                    categories=cats_idx)  # [bsz, enc_dim]

                # decode
                instructions_emb = instructions_emb.unsqueeze(
                    0)  #  [1, bsz, dim]
                instructions_emb = instructions_emb.repeat(
                    dec_inputs.size(0), 1, 1)  # [max_len + 1, bsz, dim]
                dec_inputs = torch.cat([dec_inputs, instructions_emb],
                                       dim=2)  # [max_len + 1, bsz, inp_dim]
                outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
                    dec_inputs, output_all=True)

            elif self.hp.cond_instructions == 'match':  # match decoder's hidden representations to encoded language
                # decode
                outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(
                    dec_inputs,
                    output_all=True)  # outputs: [max_seq_len, bsz, dim]
                outputs = outputs.mean(dim=0)  # [bsz, dec_dim]

                if self.hp.loss_match == 'triplet':
                    # Encode instructions
                    # text_indices: [len, bsz], text_lens: [bsz]
                    instructions_emb = self.enc(
                        text_indices,
                        text_lens,
                        self.text_embedding,
                        category_embedding=self.category_embedding,
                        categories=cats_idx)  # [bsz, enc_dim]
                    outputs = self.fc_dec(outputs)  # [bsz, enc_dim]

                    pos = (outputs - instructions_emb)**2  # [bsz, enc_dim]
                    instructions_emb_shuffled = instructions_emb[
                        torch.randperm(bsz), :]  # [bsz, enc_dim]
                    neg = (outputs -
                           instructions_emb_shuffled)**2  # [bsz, enc_dim]
                    loss_match = (pos - neg).mean() + 0.1  # positive - negative + margin alpha
                    loss_match = torch.clamp(loss_match, min=0.0)  # hinge; stays on pos's device
                elif self.hp.loss_match == 'decode':
                    hidden = nn_utils.move_to_cuda(
                        torch.zeros(self.ins_dec.num_layers, bsz,
                                    self.ins_dec.hidden_dim))
                    cell = nn_utils.move_to_cuda(
                        torch.zeros(self.ins_dec.num_layers, bsz,
                                    self.ins_dec.hidden_dim))

                    # Decode
                    texts_emb = self.text_embedding(
                        text_indices)  # [len, bsz, dim]
                    logits, texts_hidden = self.ins_dec(
                        texts_emb,
                        text_lens,
                        hidden=hidden,
                        cell=cell,
                        token_embedding=self.text_embedding,
                        category_embedding=self.category_embedding,
                        categories=cats_idx)
                    loss_match = self.compute_dec_loss(logits, text_indices)

        #
        # Calculate reconstruction and final loss
        #
        mask, dx, dy, p = self.dec.make_target(strokes, stroke_lens, self.hp.M)

        loss_R = self.dec.reconstruction_loss(mask, dx, dy, p, pi, mu_x, mu_y,
                                              sigma_x, sigma_y, rho_xy, q)

        if self.hp.cond_instructions == 'decinputs':
            loss = loss_R
            result = {'loss': loss, 'loss_R': loss_R.clone().detach()}
        elif self.hp.cond_instructions == 'match':
            loss = loss_R + loss_match
            result = {
                'loss': loss,
                'loss_R': loss_R.clone().detach(),
                'loss_match': loss_match.clone().detach()
            }

        if torch.isnan(loss).any() or torch.isinf(loss).any():
            raise Exception('NaN/Inf in SketchRNNWithPlans forward pass')

        return result
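
    # The triplet-style matching loss above, in isolation (toy tensors; shapes
    # and the margin value mirror the code, everything else is assumed):
    #   n, d = 6, 16
    #   seg = fc_dec(seg_hiddens)                  # [n, d] projected decoder states
    #   ins = instruction_embs                     # [n, d] encoded instructions
    #   pos = (seg - ins) ** 2                     # distance to the matching instruction
    #   neg = (seg - ins[torch.randperm(n)]) ** 2  # distance to a shuffled (negative) one
    #   loss = torch.clamp((pos - neg).mean() + 0.1, min=0.0)  # hinge at margin 0.1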

    def compute_dec_loss(self, logits, tf_inputs):
        """
        Args:
            logits: [len, bsz, vocab]
            tf_inputs: [len, bsz] ("teacher-forced inputs", inputs to decoder used to generate logits)
                (text_indices_w_sos_eos)
        """
        logits = logits[:-1, :, :]  # the last input that produced logits is EOS; we don't care about the mapping after EOS
        targets = tf_inputs[1:, :]  # remove the first input (sos)

        vocab_size = logits.size(-1)
        logits = logits.reshape(-1, vocab_size)
        targets = targets.reshape(-1)
        loss = F.cross_entropy(logits, targets, ignore_index=0)

        return loss
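
    # A worked toy case for the shift above (hypothetical token ids): with
    # tf_inputs = [sos, w1, w2, eos] per column, the logits emitted after
    # [sos, w1, w2] are scored against targets [w1, w2, eos]; the logits row
    # produced after eos is dropped, and ignore_index=0 assumes pad id 0.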

    ##############################################################################
    # Generate
    ##############################################################################
    def end_of_epoch_hook(self,
                          data_loader,
                          epoch,
                          outputs_path=None,
                          writer=None):
        # match not implemented / tested yet
        if (self.hp.cond_instructions == 'decinputs') and \
            (self.hp.instruction_set in ['toplevel', 'toplevel_leaves', 'leaves']):
            self.generate_and_save(data_loader,
                                   epoch,
                                   25,
                                   outputs_path=outputs_path)

    def generate_and_save(self, data_loader, epoch, n_gens, outputs_path=None):
        """
        Generate and save drawings
        """
        n = 0
        gen_strokes = []
        gt_strokes = []
        gt_texts = []
        for i, batch in enumerate(data_loader):
            batch = self.preprocess_batch_from_data_loader(batch)
            strokes, stroke_lens, texts, text_lens, text_indices, cats, cats_idx, urls = batch

            max_len, bsz, _ = strokes.size()

            if self.hp.cond_instructions == 'decinputs':
                # Encode instructions
                # text_indices: [len, bsz], text_lens: [bsz]
                instructions_emb = self.enc(
                    text_indices,
                    text_lens,
                    self.text_embedding,
                    category_embedding=self.category_embedding,
                    categories=cats_idx)  # [bsz, enc_dim]
                z = instructions_emb

            hidden_cell = (nn_utils.move_to_cuda(
                torch.zeros(1, bsz, self.hp.dec_dim)),
                           nn_utils.move_to_cuda(
                               torch.zeros(1, bsz, self.hp.dec_dim)))

            # initialize state with start of sequence stroke-5 stroke
            sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(
                0)  # [1 (len), bsz, 5 (stroke-5)]
            sos = nn_utils.move_to_cuda(sos)

            # generate until end of sequence or maximum sequence length
            s = sos
            seq_x = []  # delta-x
            seq_y = []  # delta-y
            seq_pen = []  # pen-down
            for _ in range(max_len):
                if self.hp.cond_instructions == 'decinputs':  # input is last state, z, and hidden_cell
                    input = torch.cat(
                        [s, z.unsqueeze(0)], dim=2
                    )  # [1 (len), 1 (bsz), input_dim (5) + z_dim (128)]

                elif self.hp.cond_instructions == 'match':  # input is last state and hidden_cell
                    input = s  # [1, bsz (1), 5]

                if self.hp.use_categories_dec \
                    and hasattr(self, 'category_embedding'):
                    # hack because VAE was trained with use_categories_dec=True but didn't actually have a category embedding
                    cat_embs = self.category_embedding(
                        cats_idx)  # [bsz (1), cat_dim]
                    input = torch.cat([input, cat_embs.unsqueeze(0)],
                                      dim=2)  # [1, 1, dim]
                    # dim = 5 + cat_dim if decodergmm, 5 + z_dim + cat_dim if vae

                outputs, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell = \
                    self.dec(input, stroke_lens=stroke_lens, output_all=False, hidden_cell=hidden_cell)
                hidden_cell = (hidden, cell)  # for next time step
                # sample next state
                s, dx, dy, pen_up, eos = self.sample_next_state(
                    pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q)

                seq_x.append(dx)
                seq_y.append(dy)
                seq_pen.append(pen_up)

                if eos:  # done drawing
                    break

            # get in format to draw image
            # Cumulative sum because seq_x and seq_y are deltas, so get x (or y) at each stroke
            sample_x = np.cumsum(seq_x, 0)
            sample_y = np.cumsum(seq_y, 0)
            sample_pen = np.array(seq_pen)
            sequence = np.stack([sample_x, sample_y, sample_pen]).T
            # output_fp = os.path.join(outputs_path, f'e{epoch}-gen{n}.jpg')
            # save_strokes_as_img(sequence, output_fp)

            # Save original as well
            output_fp = os.path.join(outputs_path, f'e{epoch}-gt{n}.jpg')
            strokes_x = strokes[:, 0, 0]  # middle 0: 0th batch item (sample_next_state etc. only use it); last 0: dx
            strokes_y = strokes[:, 0, 1]  # 1: dy
            strokes_x = np.cumsum(strokes_x.cpu().numpy())
            strokes_y = np.cumsum(strokes_y.cpu().numpy())
            strokes_pen = strokes[:, 0, 3].cpu().numpy()
            strokes_out = np.stack([strokes_x, strokes_y, strokes_pen]).T
            # save_strokes_as_img(strokes_out, output_fp)

            gen_strokes.append(sequence)
            gt_strokes.append(strokes_out)
            gt_texts.append(texts[0])  # 0 because batch size is 1

            n += 1
            if n == n_gens:
                break

        # save grid drawings
        rowcol_size = 5
        chunk_size = rowcol_size**2
        for i in range(0, len(gen_strokes), chunk_size):
            output_fp = os.path.join(outputs_path,
                                     f'e{epoch}_gen{i}-{i+chunk_size}.jpg')
            save_multiple_strokes_as_img(gen_strokes[i:i + chunk_size],
                                         output_fp)

            output_fp = os.path.join(outputs_path,
                                     f'e{epoch}_gt{i}-{i+chunk_size}.jpg')
            save_multiple_strokes_as_img(gt_strokes[i:i + chunk_size],
                                         output_fp)

            # save texts
            output_fp = os.path.join(outputs_path,
                                     f'e{epoch}_texts{i}-{i+chunk_size}.json')
            utils.save_file(gt_texts[i:i + chunk_size], output_fp)

    ##############################################################################
    # Load
    ##############################################################################
    def load_enc_and_catemb_from_pretrained(self, dir):
        """
        Copy pretrained encoder and category embedding weights. They must
        be the same dimension as self's dimension.s.

        Args:
            dir: str
        """
        print(
            'Loading encoder and category embedidng weights from pretrained: ',
            dir)
        fp = os.path.join(dir, 'model.pt')
        trained_state_dict = torch.load(fp)
        self_state_dict = self.state_dict()

        for name, param in trained_state_dict.items():
            if name.startswith('enc') or name.startswith('category_embedding'):
                self_state_dict[name].copy_(param)
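
# A minimal sketch of the partial state_dict copy performed in
# load_enc_and_catemb_from_pretrained, on toy modules (the module names here
# are illustrative assumptions):
def _demo_partial_state_dict_copy():
    import torch.nn as nn
    src = nn.ModuleDict({'enc': nn.Linear(4, 4), 'dec': nn.Linear(4, 4)})
    dst = nn.ModuleDict({'enc': nn.Linear(4, 4), 'dec': nn.Linear(4, 4)})
    dst_state_dict = dst.state_dict()
    for name, param in src.state_dict().items():
        if name.startswith('enc'):
            # state_dict() returns detached tensors that share storage with the
            # parameters, so copy_ overwrites dst's 'enc' weights in place
            dst_state_dict[name].copy_(param)
    return dst
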
class SketchRNNVAEModel(SketchRNNModel):
    """
    Create entire sketch model by combining encoder and decoder. Training, generation, etc.
    """

    def __init__(self, hp, save_dir, skip_data=False):
        super().__init__(hp, save_dir, skip_data=skip_data)

        # Model
        self.enc = SketchRNNVAEEncoder(5, hp.enc_dim, hp.enc_num_layers, hp.z_dim,
                                       dropout=hp.dropout,
                                       use_layer_norm=hp.use_layer_norm,
                                       rec_dropout=hp.rec_dropout)
        self.fc_z_to_hc = nn.Linear(hp.z_dim, 2 * hp.dec_dim)  # 2: 1 for hidden, 1 for cell

        # Decoder
        inp_dim = 5 + hp.z_dim
        self.category_embedding = None
        if hp.use_categories_dec:
            self.category_embedding = nn.Embedding(35, self.hp.categories_dim)
            self.models.append(self.category_embedding)
            inp_dim += hp.categories_dim
        self.dec = SketchRNNDecoderGMM(inp_dim, hp.dec_dim, hp.M, dropout=hp.dropout,
            use_layer_norm=self.hp.use_layer_norm, rec_dropout=self.hp.rec_dropout)

        self.models.extend([self.enc, self.fc_z_to_hc, self.dec])
        if USE_CUDA:
            for model in self.models:
                model.cuda()

        # optimization -- ADAM plus annealing (supp eq. 4)
        self.optimizers.append(optim.Adam(self.parameters(), hp.lr))

    def one_forward_pass(self, batch):
        """
        Return loss and other items of interest for one forward pass

        Args:
            batch: tuple from loader

        Returns:
            dict where 'loss': float Tensor must exist
        """
        strokes, stroke_lens, cats, cats_idx = batch
        max_len, bsz, _ = strokes.size()

        # Encode
        z, mu, sigma_hat = self.enc(strokes)  # each [bsz, z_dim]

        # Create inputs to decoder
        sos = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * bsz).unsqueeze(0)  # start of sequence
        sos = nn_utils.move_to_cuda(sos)
        inputs_init = torch.cat([sos, strokes], 0)  # prepend sos to the strokes; [max_len + 1, bsz, 5]
        z_stack = torch.stack([z] * (max_len + 1), dim=0)  # expand z to concat with inputs; [max_len + 1, bsz, z_dim]
        dec_inputs = torch.cat([inputs_init, z_stack], 2)  # each input is stroke + z; [max_len + 1, bsz, 5 + z_dim]

        if self.hp.use_categories_dec:
            cat_embs = self.category_embedding(cats_idx)  # [bsz, cat_dim]
            cat_embs = cat_embs.repeat(dec_inputs.size(0), 1, 1)  # [max_len + 1, bsz, cat_dim]
            dec_inputs = torch.cat([dec_inputs, cat_embs], dim=2)  # [max_len+1, bsz, 5 + z_dim + cat_dim]

        # initial hidden and cell states are tanh(fc(z)) (paper, page 3)
        hidden, cell = torch.split(torch.tanh(self.fc_z_to_hc(z)), self.hp.dec_dim, 1)
        hidden_cell = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
        # TODO: if we want multiple layers, we need to replicate hidden and cell n_layers times

        # Decode
        _, pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _, _ = self.dec(dec_inputs, output_all=True,
                                                                        hidden_cell=hidden_cell)

        # Calculate losses
        mask, dx, dy, p = self.dec.make_target(strokes, stroke_lens, self.hp.M)
        loss_KL_final, loss_KL_thresh, loss_KL = self.kullback_leibler_loss(sigma_hat, mu, self.hp.KL_min, self.hp.wKL, self.eta_step)
        loss_R = self.dec.reconstruction_loss(mask,
                                              dx, dy, p,
                                              pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy,
                                              q)
        loss = loss_KL_final + loss_R
        result = {'loss': loss, 'loss_R': loss_R,
                  'loss_KL_final': loss_KL_final, 'loss_KL_thresh': loss_KL_thresh, 'loss_KL': loss_KL,
                  'loss_KL_eta': torch.Tensor([self.eta_step]) # "loss" for logging purposes; convert to Tensor because .item() is called on each value
                  }

        return result
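
    # The z -> (hidden, cell) initialization above, as a shape sketch (dims
    # assumed): with z of [bsz, z_dim] and fc_z_to_hc = Linear(z_dim, 2 * dec_dim),
    #   hidden, cell = torch.split(torch.tanh(fc_z_to_hc(z)), dec_dim, 1)
    # yields two [bsz, dec_dim] halves, each unsqueezed to [1, bsz, dec_dim]
    # for the single-layer decoder LSTM.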

    #
    # KL loss
    #
    def kullback_leibler_loss(self, sigma_hat, mu, KL_min, wKL, eta_step):
        """
        Calculate KL loss -- (eq. 10, 11)

        Args:
            sigma_hat: [bsz, z_dim]
            mu: [bsz, z_dim]
            KL_min: float
            wKL: float
            eta_step: float

        Returns: float Tensor
        """
        bsz, z_dim = sigma_hat.size()

        LKL = -0.5 * torch.sum(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat)) \
              / float(bsz * z_dim)
        KL_min = torch.Tensor([KL_min])
        KL_min = nn_utils.move_to_cuda(KL_min)
        KL_min = KL_min.detach()

        LKL_thresh = torch.max(LKL, KL_min)
        LKL_final = wKL * eta_step * LKL_thresh
        return LKL_final, LKL_thresh, LKL
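
    # Quick numeric check of the formula above: at the prior (mu = 0,
    # sigma_hat = 0) each term is 1 + 0 - 0 - exp(0) = 0, so LKL = 0; any
    # deviation makes LKL > 0, since 1 + s - exp(s) <= 0 and -mu**2 <= 0.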

    def save_vaez_inference_time(self, model_dir):
        """
        Loop over all ndjson data and save the encoded z's for each.

        Args:
            model_dir (str): directory of the trained model we are using
        """
        torch.backends.cudnn.benchmark = True  # let cudnn autotune kernel selection
        print('Saving to: ', os.path.join(model_dir, 'inference_vaez'))
        with torch.no_grad():
            for split in ['train', 'valid', 'test']:
                loader = self.get_data_loader(split, self.hp.batch_size, self.hp.categories, self.hp.max_len, self.hp.max_per_category, False)

                for bidx, batch in enumerate(loader):
                    # encode drawing and get z
                    batch = self.preprocess_batch_from_data_loader(batch)
                    strokes, stroke_lens, cats, cats_idx = batch
                    max_len, cur_bsz, _ = strokes.size()
                    z, _, _ = self.enc(strokes)  # z: [bsz, 128]
                    z_np = z.cpu().numpy()

                    for i in range(cur_bsz):
                        # get index of this drawing within the dataset (loader has shuffle=False so we can do this)
                        global_idx = bidx * self.hp.batch_size + i
                        drawing_id = loader.dataset.data[global_idx]['id']

                        # save numpy version of z
                        out_dir = os.path.join(model_dir, 'inference_vaez', cats[i])
                        os.makedirs(out_dir, exist_ok=True)
                        out_fp = os.path.join(out_dir, f'{drawing_id}.pkl')

                        cur_z_np = z_np[i]
                        utils.save_file(cur_z_np, out_fp)
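
# The global-index bookkeeping in save_vaez_inference_time relies on the loader
# being built with shuffle=False. A toy check (values assumed): with
# batch_size = 4, batch bidx = 2 and within-batch position i = 3 map to
# global_idx = 2 * 4 + 3 = 11, i.e. the 12th drawing in the split.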