Example #1
    def __init__(self, config, use_attention=True, encoder=None, decoder=None):
        super(seq2seq, self).__init__()

        if encoder is not None:
            self.encoder = encoder
        else:
            if config.transformer:
                # n_src_vocab, len_max_seq, d_word_vec, n_layers, n_head, d_k, d_v, d_model, d_inner
                self.encoder = Encoder(config.src_vocab_size, 100,
                                       config.emb_size, 6, 8,
                                       int(config.hidden_size / 8),
                                       int(config.hidden_size / 8),
                                       config.hidden_size,
                                       config.hidden_size * 4)
            else:
                self.encoder = models.rnn_encoder(config)
        tgt_embedding = self.encoder.embedding if config.shared_vocab else None
        if decoder is not None:
            self.decoder = decoder
        else:
            self.decoder = models.rnn_decoder(config,
                                              embedding=tgt_embedding,
                                              use_attention=use_attention)
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.use_cuda = config.use_cuda
        self.config = config
        self.criterion = nn.CrossEntropyLoss(ignore_index=utils.PAD,
                                             reduction='none')
        if config.use_cuda:
            self.criterion.cuda()
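    # A minimal construction sketch (hypothetical config object; only the
    # fields read directly above are listed, and all values are illustrative):
    #
    #   from argparse import Namespace
    #   config = Namespace(transformer=False, shared_vocab=True,
    #                      src_vocab_size=30000, emb_size=256,
    #                      hidden_size=512, use_cuda=False)
    #   model = seq2seq(config, use_attention=True)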
    def __init__(self,
                 label_num=4,
                 fix_bert=False,
                 context_encoder=None,
                 bilstm=True,
                 dropout=0,
                 double_supervision=False,
                 emoji_vectors=None):
        super(HierBertModel, self).__init__()
        self.sentences_encoder = BertModel.from_pretrained('bert-base-cased')
        self.context_encoder = context_encoder
        self.bilstm = bilstm
        self.double_supervision = double_supervision
        if emoji_vectors is not None:
            self.emoji_emb = nn.Embedding.from_pretrained(
                torch.FloatTensor(emoji_vectors))
            self.emoji_dim = emoji_vectors.shape[1]
        else:
            # no pretrained emoji vectors supplied
            self.emoji_emb = None
            self.emoji_dim = 0
        if fix_bert:
            for param in self.sentences_encoder.parameters():
                param.requires_grad = False

        self.hidden_size = self.sentences_encoder.config.hidden_size
        self.dropout = torch.nn.Dropout(dropout)
        if context_encoder == 'lstm':
            self.hier_encoder = nn.LSTM(self.hidden_size,
                                        self.hidden_size,
                                        batch_first=True,
                                        bidirectional=bilstm)
        else:
            self.hier_encoder = Encoder(
                self.hidden_size + self.emoji_dim if
                (emoji_vectors is not None) else self.hidden_size,
                self.hidden_size,
                constant.hop,
                constant.heads,
                constant.depth,
                constant.depth,
                constant.filter,
                max_length=3,
                input_dropout=0,
                layer_dropout=0,
                attention_dropout=0,
                relu_dropout=0,
                use_mask=False,
                act=constant.act)
        self.classifer = nn.Linear(self.hidden_size, label_num)
        if double_supervision:
            self.super = nn.Linear(
                (self.hidden_size + self.emoji_dim) * 3 if
                (emoji_vectors is not None) else self.hidden_size * 3,
                label_num)

        if (self.context_encoder == 'lstm' and self.bilstm):
            self.classifer = nn.Linear(self.hidden_size * 2, label_num)
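    # A minimal construction sketch (assumption: emoji_vectors is a float
    # array of shape [n_emoji, emoji_dim]; all values are illustrative):
    #
    #   import numpy as np
    #   emoji_vecs = np.random.randn(1500, 300).astype(np.float32)
    #   model = HierBertModel(label_num=4, fix_bert=True,
    #                         context_encoder='lstm', bilstm=True,
    #                         dropout=0.1, emoji_vectors=emoji_vecs)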
    def __init__(self,
                 vocab_dict,
                 dropout_rate,
                 embed_dim,
                 hidden_dim,
                 bidirectional=True):
        super(AoAReader, self).__init__()
        self.vocab_dict = vocab_dict
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.dropout_rate = dropout_rate

        self.embedding = nn.Embedding(vocab_dict.size(),
                                      self.embed_dim,
                                      padding_idx=Constants.PAD)
        self.embedding.weight.data.uniform_(-0.05, 0.05)

        input_size = self.embed_dim

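        # The positional arguments below presumably mirror the Encoder call used
        # elsewhere in this collection: (embedding_size, hidden_size, num_layers,
        # num_heads, key_depth, value_depth, filter_size). This is an assumption,
        # not something this snippet confirms.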
        self.transformer = Encoder(embed_dim,
                                   hidden_dim,
                                   1,
                                   1,
                                   hidden_dim,
                                   hidden_dim,
                                   64,
                                   max_length=2000,
                                   input_dropout=0.0,
                                   layer_dropout=0.0,
                                   attention_dropout=0.0,
                                   relu_dropout=0.0,
                                   use_mask=False)

        for weight in self.transformer.parameters():
            if len(weight.size()) > 1:
                nn.init.orthogonal_(weight.data)
Example #4
    def __init__(self, encoder_latent, encoder_out, fspool_n_pieces, transformer_layers, transformer_attn_size,
                 transformer_num_heads, num_element_features, size_pred_width, pad_value, max_set_size):
        super(Tspn, self).__init__()

        self.pad_value = pad_value
        self.max_set_size = max_set_size
        self.num_element_features = num_element_features

        self._prior = SetPrior(num_element_features)

        self._encoder = FSEncoder(encoder_latent, encoder_out, fspool_n_pieces)
        self._transformer = Encoder(transformer_layers, transformer_attn_size, transformer_num_heads)

        # initialise the output to predict points at the center of our canvas
        self._set_prediction = tf.keras.layers.Conv1D(num_element_features, 1, kernel_initializer='zeros',
                                                     bias_initializer=tf.keras.initializers.constant(0.5),
                                                     use_bias=True)

        self._size_predictor = SizePredictor(size_pred_width, max_set_size)
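# A minimal construction sketch for the set prediction model above (all
# hyperparameter values are illustrative, not taken from any original
# configuration):
#
#   tspn = Tspn(encoder_latent=256, encoder_out=512, fspool_n_pieces=20,
#               transformer_layers=4, transformer_attn_size=128,
#               transformer_num_heads=4, num_element_features=2,
#               size_pred_width=128, pad_value=-1.0, max_set_size=360)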
Example #5
    def __init__(self, yaml_path):
        config_file = yaml_path
        config = yaml.load(open(config_file), Loader=yaml.FullLoader)
        args = config["training"]
        SEED = args["seed"]
        DATASET = args["dataset"]  # Multi30k or IWSLT
        MODEL = args["model"]  # gru**2, gru_attn**2, transformer, gcn_gru, gcngru_gru, gcngruattn_gru, gcnattn_gru
        REVERSE = args["reverse"]
        BATCH_SIZE = args["batch_size"]
        ENC_EMB_DIM = args["encoder_embed_dim"]
        DEC_EMB_DIM = args["decoder_embed_dim"]
        ENC_HID_DIM = args["encoder_hidden_dim"]
        DEC_HID_DIM = args["decoder_hidden_dim"]
        ENC_DROPOUT = args["encoder_dropout"]
        DEC_DROPOUT = args["decoder_dropout"]
        NLAYERS = args["num_layers"]
        N_EPOCHS = args["num_epochs"]
        CLIP = args["grad_clip"]
        LR = args["lr"]
        LR_DECAY_RATIO = args["lr_decay_ratio"]
        ID = args["id"]
        PATIENCE = args["patience"]
        DIR = 'checkpoints/{}-{}-{}/'.format(DATASET, MODEL, ID)
        MODEL_PATH = DIR
        LOG_PATH = '{}test-log.log'.format(DIR)
        set_seed(SEED)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.config = args
        self.device = device

        if 'transformer' in MODEL:
            ENC_HEADS = args["encoder_heads"]
            DEC_HEADS = args["decoder_heads"]
            ENC_PF_DIM = args["encoder_pf_dim"]
            DEC_PF_DIM = args["decoder_pf_dim"]
            MAX_LEN = args["max_len"]
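        # Sketch of the expected "training" section of the YAML config, inferred
        # from the keys read above (all values are illustrative, not the
        # project's defaults):
        #
        #   training:
        #     seed: 1234
        #     dataset: Multi30k
        #     model: transformer
        #     reverse: false
        #     batch_size: 128
        #     encoder_embed_dim: 256
        #     decoder_embed_dim: 256
        #     encoder_hidden_dim: 512
        #     decoder_hidden_dim: 512
        #     encoder_dropout: 0.1
        #     decoder_dropout: 0.1
        #     num_layers: 1
        #     num_epochs: 20
        #     grad_clip: 1.0
        #     lr: 0.0005
        #     lr_decay_ratio: 0.5
        #     id: 0
        #     patience: 5
        #     encoder_heads: 8        # transformer variants only
        #     decoder_heads: 8
        #     encoder_pf_dim: 512
        #     decoder_pf_dim: 512
        #     max_len: 100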
            
        SRC = Field(tokenize = lambda text: tokenize_de(text, REVERSE), 
                    init_token = '<sos>', 
                    eos_token = '<eos>', 
                    lower = True)
        TGT = Field(tokenize = tokenize_en, 
                    init_token = '<sos>', 
                    eos_token = '<eos>', 
                    lower = True)
        GRH = RawField(postprocessing=batch_graph)
        data_fields = [('src', SRC), ('trg', TGT), ('grh', GRH)]
        
        train_data = Dataset(torch.load("data/Multi30k/train_data.pt"), data_fields)
        valid_data = Dataset(torch.load("data/Multi30k/valid_data.pt"), data_fields)
        test_data = Dataset(torch.load("data/Multi30k/test_data.pt"), data_fields)
        self.train_data, self.valid_data, self.test_data = train_data, valid_data, test_data
        
        train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
            (train_data, valid_data, test_data), 
            batch_size = BATCH_SIZE, 
            sort_key = lambda x: len(x.src),
            sort_within_batch=False,
            device = device)
        self.train_iterator, self.valid_iterator, self.test_iterator = train_iterator, valid_iterator, test_iterator
        
        SRC.build_vocab(train_data, min_freq = 2)
        TGT.build_vocab(train_data, min_freq = 2)
        self.SRC, self.TGT, self.GRH = SRC, TGT, GRH

        print(f"Number of training examples: {len(train_data.examples)}")
        print(f"Number of validation examples: {len(valid_data.examples)}")
        print(f"Number of testing examples: {len(test_data.examples)}")
        print(f"Unique tokens in source (de) vocabulary: {len(SRC.vocab)}")
        print(f"Unique tokens in target (en) vocabulary: {len(TGT.vocab)}")

        src_c, tgt_c = get_sentence_lengths(train_data)
        src_lengths = counter2array(src_c)
        tgt_lengths = counter2array(tgt_c)
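        # counter2array (assumed helper, not shown here) presumably flattens a
        # Counter of {sentence_length: count} into a numpy array of per-sentence
        # lengths so that np.quantile can be applied, e.g.
        #   np.array([l for l, c in counter.items() for _ in range(c)])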

        print("maximum src, tgt sent lengths: ",
              np.quantile(src_lengths, 1), np.quantile(tgt_lengths, 1))

        # Get models and corresponding training scripts

        INPUT_DIM = len(SRC.vocab)
        OUTPUT_DIM = len(TGT.vocab)
        SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
        TGT_PAD_IDX = TGT.vocab.stoi[TGT.pad_token]
        self.SRC_PAD_IDX = SRC_PAD_IDX
        self.TGT_PAD_IDX = TGT_PAD_IDX

        if MODEL == "gru**2":  # gru**2, gru_attn**2, transformer, gcn_gru
            from models.gru_seq2seq import GRUEncoder, GRUDecoder, Seq2Seq
            enc = GRUEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT)
            model = Seq2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gru, evaluate_gru, epoch_time
            train_epoch = train_epoch_gru
            evaluate = evaluate_gru
            
            self.enc, self.dec, self.model, self.train_epoch, self.evaluate = enc, dec, model, train_epoch, evaluate
            
        elif MODEL == "gru_attn**2":
            from models.gru_attn import GRUEncoder, GRUDecoder, Seq2Seq, Attention
            attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
            enc = GRUEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT, attn)
            model = Seq2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gru_attn, evaluate_gru_attn, epoch_time
            train_epoch = train_epoch_gru_attn
            evaluate = evaluate_gru_attn
            
            self.enc, self.dec, self.model, self.train_epoch, self.evaluate, self.attn = enc, dec, model, train_epoch, evaluate, attn

        elif MODEL == "transformer":
            from models.transformer import Encoder, Decoder, Seq2Seq
            enc = Encoder(INPUT_DIM, ENC_HID_DIM, NLAYERS, ENC_HEADS, 
                          ENC_PF_DIM, ENC_DROPOUT, device, MAX_LEN)
            dec = Decoder(OUTPUT_DIM, DEC_HID_DIM, NLAYERS, DEC_HEADS, 
                          DEC_PF_DIM, DEC_DROPOUT, device, MAX_LEN)
            model = Seq2Seq(enc, dec, SRC_PAD_IDX, TGT_PAD_IDX, device).to(device)

            from src.train import train_epoch_tfmr, evaluate_tfmr, epoch_time
            train_epoch = train_epoch_tfmr
            evaluate = evaluate_tfmr

            self.enc, self.dec, self.model, self.train_epoch, self.evaluate = enc, dec, model, train_epoch, evaluate
            
        elif MODEL == "gcn_gru":
            from models.gru_seq2seq import GCNEncoder, GRUDecoder, GCN2Seq
            enc = GCNEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, NLAYERS, ENC_DROPOUT)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT)
            model = GCN2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gcn_gru, evaluate_gcn_gru, epoch_time
            train_epoch = train_epoch_gcn_gru
            evaluate = evaluate_gcn_gru

            self.enc, self.dec, self.model, self.train_epoch, self.evaluate = enc, dec, model, train_epoch, evaluate
            
        elif MODEL == "gcngru_gru":
            from models.gru_seq2seq import GCNGRUEncoder, GRUDecoder, GCN2Seq
            enc = GCNGRUEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT, device)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT)
            model = GCN2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gcn_gru, evaluate_gcn_gru, epoch_time
            train_epoch = train_epoch_gcn_gru
            evaluate = evaluate_gcn_gru

            self.enc, self.dec, self.model, self.train_epoch, self.evaluate = enc, dec, model, train_epoch, evaluate
            
        elif MODEL == "gcnattn_gru":
            from models.gru_attn import GCNEncoder, GRUDecoder, GCN2Seq, Attention
            attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
            enc = GCNEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT, attn)
            model = GCN2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gcnattn_gru, evaluate_gcnattn_gru, epoch_time
            train_epoch = train_epoch_gcnattn_gru
            evaluate = evaluate_gcnattn_gru
            
            self.enc, self.dec, self.model, self.train_epoch, self.evaluate, self.attn = enc, dec, model, train_epoch, evaluate, attn

        elif MODEL == "gcngruattn_gru":
            from models.gru_attn import GCNGRUEncoder, GRUDecoder, GCN2Seq, Attention
            attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
            enc = GCNGRUEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT, device)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT, attn)
            model = GCN2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gcnattn_gru, evaluate_gcnattn_gru, epoch_time
            train_epoch = train_epoch_gcnattn_gru
            evaluate = evaluate_gcnattn_gru
            
            self.enc, self.dec, self.model, self.train_epoch, self.evaluate, self.attn = enc, dec, model, train_epoch, evaluate, attn

        else:
            raise ValueError("Wrong model choice")

        if 'gcn' in MODEL:
            from src.utils import init_weights_uniform as init_weights
        else: 
            from src.utils import init_weights_xavier as init_weights

        model.apply(init_weights)
        n_params = count_parameters(model)
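        # count_parameters (assumed helper) is typically
        #   sum(p.numel() for p in model.parameters() if p.requires_grad)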
        print("Model initialized...{} params".format(n_params))
        
        self.criterion = nn.CrossEntropyLoss(ignore_index=TGT_PAD_IDX)
        
        print(os.path.join(MODEL_PATH, "checkpoint.pt"))
        state_dict = torch.load(os.path.join(MODEL_PATH, "checkpoint.pt"), map_location=device)
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
        model.load_state_dict(state_dict)
        self.model = model
class AoAReader(nn.Module):
    def __init__(self,
                 vocab_dict,
                 dropout_rate,
                 embed_dim,
                 hidden_dim,
                 bidirectional=True):
        super(AoAReader, self).__init__()
        self.vocab_dict = vocab_dict
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.dropout_rate = dropout_rate

        self.embedding = nn.Embedding(vocab_dict.size(),
                                      self.embed_dim,
                                      padding_idx=Constants.PAD)
        self.embedding.weight.data.uniform_(-0.05, 0.05)

        input_size = self.embed_dim

        self.transformer = Encoder(embed_dim,
                                   hidden_dim,
                                   1,
                                   1,
                                   hidden_dim,
                                   hidden_dim,
                                   64,
                                   max_length=2000,
                                   input_dropout=0.0,
                                   layer_dropout=0.0,
                                   attention_dropout=0.0,
                                   relu_dropout=0.0,
                                   use_mask=False)

        for weight in self.transformer.parameters():
            if len(weight.size()) > 1:
                nn.init.orthogonal_(weight.data)

    def forward(self,
                docs_input,
                docs_len,
                doc_mask,
                querys_input,
                querys_len,
                query_mask,
                candidates=None,
                answers=None):
        s_docs, s_docs_len, reverse_docs_idx = sort_batch(docs_input, docs_len)
        s_querys, s_querys_len, reverse_querys_idx = sort_batch(
            querys_input, querys_len)

        docs_embedding = self.embedding(s_docs)
        querys_embedding = self.embedding(s_querys)

        docs_outputs = self.transformer(docs_embedding)
        querys_outputs = self.transformer(querys_embedding)

        docs_outputs = docs_outputs[reverse_docs_idx.data]
        querys_outputs = querys_outputs[reverse_querys_idx.data]

        # transpose query for pair-wise dot product
        dos = docs_outputs
        doc_mask = doc_mask.unsqueeze(2)
        qos = torch.transpose(querys_outputs, 1, 2)
        query_mask = query_mask.unsqueeze(2)

        # pair-wise matching score
        M = torch.bmm(dos, qos)
        M_mask = torch.bmm(doc_mask, query_mask.transpose(1, 2))
        # query-document attention
        alpha = softmax_mask(M, M_mask, axis=1)
        beta = softmax_mask(M, M_mask, axis=2)

        sum_beta = torch.sum(beta, dim=1, keepdim=True)

        docs_len = docs_len.unsqueeze(1).unsqueeze(2).expand_as(sum_beta)
        average_beta = sum_beta / docs_len.float()

        # attended document-level attention
        s = torch.bmm(alpha, average_beta.transpose(1, 2))
        # predict the most possible answer from given candidates
        pred_answers = None
        #pred_locs = None
        probs = None
        if candidates is not None:
            pred_answers = []
            pred_locs = []
            for i, cands in enumerate(candidates):
                pb = []
                document = docs_input[i].squeeze()
                for j, candidate in enumerate(cands):
                    pointer = document == candidate.expand_as(document)
                    pb.append(
                        torch.sum(torch.masked_select(s[i].squeeze(), pointer),
                                  0,
                                  keepdim=True))
                pb = torch.cat(pb, dim=0).squeeze()
                _, max_loc = torch.max(pb, 0)
                pred_answers.append(cands.index_select(0, max_loc))
                pred_locs.append(max_loc)
            pred_answers = torch.cat(pred_answers, dim=0).squeeze()
            #pred_locs = torch.cat(pred_locs, dim=0).squeeze()

        if answers is not None:
            probs = []
            for i, answer in enumerate(answers):
                document = docs_input[i].squeeze()
                pointer = document == answer.expand_as(document)

                this_prob = torch.sum(torch.masked_select(
                    s[i].squeeze(), pointer),
                                      0,
                                      keepdim=True)
                probs.append(this_prob)
            probs = torch.cat(probs, 0).squeeze()

        return pred_answers, probs
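
# A minimal sketch of the softmax_mask helper used in forward() above (the
# helper itself is assumed, not shown in this example): a numerically stable
# softmax along one axis that zeroes out masked positions.
import torch


def softmax_mask(scores, mask, axis=1, epsilon=1e-12):
    # shift by the per-slice maximum for numerical stability
    shift, _ = torch.max(scores, axis, keepdim=True)
    # exponentiate and zero out masked entries
    target_exp = torch.exp(scores - shift) * mask.float()
    # renormalise along the chosen axis; epsilon guards all-masked rows
    normalize = torch.sum(target_exp, axis, keepdim=True)
    return target_exp / (normalize + epsilon)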
Example #7
elif MODEL == "gru_attn**2":
    from models.gru_attn import GRUEncoder, GRUDecoder, Seq2Seq, Attention
    attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
    enc = GRUEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS,
                     ENC_DROPOUT)
    dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM,
                     NLAYERS, DEC_DROPOUT, attn)
    model = Seq2Seq(enc, dec, device).to(device)

    from src.train import train_epoch_gru_attn, evaluate_gru_attn, epoch_time
    train_epoch = train_epoch_gru_attn
    evaluate = evaluate_gru_attn

elif MODEL == "transformer":
    from models.transformer import Encoder, Decoder, Seq2Seq
    enc = Encoder(INPUT_DIM, ENC_HID_DIM, NLAYERS, ENC_HEADS, ENC_PF_DIM,
                  ENC_DROPOUT, device, MAX_LEN)
    dec = Decoder(OUTPUT_DIM, DEC_HID_DIM, NLAYERS, DEC_HEADS, DEC_PF_DIM,
                  DEC_DROPOUT, device, MAX_LEN)
    model = Seq2Seq(enc, dec, SRC_PAD_IDX, TGT_PAD_IDX, device).to(device)

    from src.train import train_epoch_tfmr, evaluate_tfmr, epoch_time
    train_epoch = train_epoch_tfmr
    evaluate = evaluate_tfmr

elif MODEL == "gcn_gru":
    from models.gru_seq2seq import GCNEncoder, GRUDecoder, GCN2Seq
    enc = GCNEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, NLAYERS, ENC_DROPOUT)
    dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM,
                     NLAYERS, DEC_DROPOUT)
    model = GCN2Seq(enc, dec, device).to(device)
def main(args):
    # Experiment settings
    experiment_name = 'Transformer Baseline'
    log = Log(args.debug)
    hp = Hyperparameters(args)
    device = set_device(log)

    # Record experiment
    debug_warning(args.debug, log)
    log(experiment_name)
    log(vars(hp))

    # Dataloaders
    train_data = json.load(open(hp.data_dir_tr, "r"))

    train_ids, val_ids = splits(train_data, log.debug, SEED)
    train_loader = get_loader(protein_data=train_data,
                              ids=train_ids,
                              hp=hp,
                              device=device,
                              shuffle=True)
    val_loader = get_loader(protein_data=train_data,
                            ids=val_ids,
                            hp=hp,
                            device=device,
                            shuffle=False)

    # Model init
    encoder = Encoder(hp.input_dim,
                      hp.hid_dim,
                      hp.enc_layers,
                      hp.enc_heads,
                      hp.enc_pf_dim,
                      hp.enc_dropout,
                      device,
                      max_length=hp.max_seq_len)
    decoder = Decoder(hp.input_dim,
                      hp.hid_dim,
                      hp.dec_layers,
                      hp.dec_heads,
                      hp.dec_pf_dim,
                      hp.dec_dropout,
                      device,
                      max_length=hp.max_seq_len)
    model = Transformer(encoder, decoder, device).to(device)
    model.apply(initialize_weights)

    # Criterion, optimizer, scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=hp.lr)
    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

    # Trackers
    best_valid_loss = float('inf')
    losses = []
    early_stop_count = hp.early_stop

    # Training
    for epoch in range(hp.n_epochs):
        memory_usage()
        start_time = time.time()
        train_loss = train(model, train_loader, optimizer, criterion,
                           scheduler, device)
        valid_loss = evaluate(model, val_loader, criterion, device)
        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # Record epoch
        log(f'Epoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs}s | Train Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}'
            )
        losses.append((train_loss, valid_loss))
        graph_losses(losses, log)

        # Track best val, save model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            early_stop_count = hp.early_stop  # reset early stop
            if not log.debug:
                model_dict = {'model': model.state_dict()}
                location = '{}/model_dict.bin'.format(log.main_dir)
                torch.save(model_dict, location)
                log('Saved model.')
                del model_dict
                torch.cuda.empty_cache()
                # save_model(model, optimizer, log)
        else:
            early_stop_count -= 1
            if early_stop_count == 0:
                log(f'Stopping early! Epoch: {epoch}')
                break
        del train_loss, valid_loss
        torch.cuda.empty_cache()
        ######## END EPOCH ########

    # Generate and save new PSSM
    tr6614 = json.load(open(hp.data_dir_no_pssm, "r"))
    gen_ids = set_gen_ids(tr6614, log.debug)
    gen_loader = get_loader(protein_data=tr6614,
                            ids=gen_ids,
                            hp=hp,
                            device=device,
                            shuffle=False)
    checkpoint = torch.load('{}/model_dict.bin'.format(log.main_dir),
                            map_location=torch.device(device))
    model.load_state_dict(checkpoint['model'])
    new_pssm = generate_whole(model, gen_loader, device, hp,
                              log)  # [N, 700, 22]
    save_pssm(new_pssm, hp, log)
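
# A minimal sketch of the epoch_time helper used in the training loop above
# (assumed, not the project's actual implementation):
def epoch_time(start_time, end_time):
    elapsed = end_time - start_time
    minutes = int(elapsed // 60)
    seconds = int(elapsed - minutes * 60)
    return minutes, seconds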