Example #1
def get_encoder(latent_dim, fckpt='', ker_size=11):
    E = Encoder(z_dim=latent_dim, first_filter_size=ker_size)
    if fckpt and os.path.exists(fckpt):
        ckpt = torch.load(fckpt)
        loaded_sd = ckpt['E']
        try:
            E.load_state_dict(loaded_sd)
        except RuntimeError:  # key/shape mismatch, fall back to a partial load
            curr_params = E.state_dict()
            curr_keys = list(curr_params.keys())

            updated_params = {}
            for k, v in loaded_sd.items():
                if 'bn7' in k:
                    newk = k.replace('bn7', 'conv7')
                else:
                    newk = k
                if newk in curr_keys and loaded_sd[k].shape == curr_params[
                        newk].shape:
                    updated_params[newk] = v
                else:
                    print('Failed to load:', k)
            curr_params.update(updated_params)
            E.load_state_dict(curr_params)
    return E
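
A minimal usage sketch for get_encoder above (the latent size and checkpoint path are illustrative assumptions):

# Hypothetical usage; any weights that fail to match keep their random init.
E = get_encoder(latent_dim=128, fckpt='checkpoints/encoder.pth', ker_size=11)
E.eval()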
Example #2
def train(dct_size, embed_size=256, hidden_size=512, epochs=10, num_layers=1, save_step=1000, lr=0.001, model_save='model/'):
    encoder = Encoder(embed_size=embed_size).to(device)
    decoder = Decoder(embed_size=embed_size, hidden_size=hidden_size, dct_size=dct_size, num_layers=num_layers).to(device)
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)

    for epoch in range(epochs):
        print(f'epoch {epoch+1}/{epochs}: ')
        for i, (images, captions, lengths) in enumerate(tqdm(data_loader)):
            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
            
            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()
            if (i+1) % save_step == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    model_save, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
                torch.save(encoder.state_dict(), os.path.join(
                    model_save, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
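
The trailing [0] on pack_padded_sequence selects the .data field of the returned PackedSequence, i.e. the flat, padding-free tensor that CrossEntropyLoss expects as targets. A small self-contained illustration (the values are made up):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

caps = torch.tensor([[1, 2, 3], [4, 5, 0]])  # batch of 2, padded to length 3
packed = pack_padded_sequence(caps, [3, 2], batch_first=True)
print(packed.data)  # tensor([1, 4, 2, 5, 3]) -- time-major, padding dropped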
Example #3
def train(model_path=None):
    dataloader = DataLoader(Augmentation())
    encoder = Encoder()
    dict_len = len(dataloader.data.dictionary)
    decoder = DecoderWithAttention(dict_len)

    if cuda:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    # if model_path:
    #   text_generator.load_state_dict(torch.load(model_path))
    train_iter = 1
    encoder_optimizer = torch.optim.Adam(encoder.parameters(),
                                         lr=cfg.encoder_learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(),
                                         lr=cfg.decoder_learning_rate)

    val_bleu = list()
    losses = list()
    while True:
        batch_image, batch_label = dataloader.get_next_batch()
        batch_image = torch.from_numpy(batch_image).type(torch.FloatTensor)
        batch_label = torch.from_numpy(batch_label).type(torch.LongTensor)
        if cuda:
            batch_image = batch_image.cuda()
            batch_label = batch_label.cuda()
        # print(batch_image.size())
        # print(batch_label.size())

        print('Training')
        output = encoder(batch_image)
        # print('encoder output:', output.size())
        predictions, alphas = decoder(output, batch_label)

        loss = cal_loss(predictions, batch_label, alphas, 1)

        decoder_optimizer.zero_grad()
        encoder_optimizer.zero_grad()
        loss.backward()
        decoder_optimizer.step()
        encoder_optimizer.step()

        print('Iter', train_iter, '| loss:',
              loss.item(), '| batch size:', cfg.batch_size,
              '| encoder learning rate:', cfg.encoder_learning_rate,
              '| decoder learning rate:', cfg.decoder_learning_rate)
        losses.append(loss.item())
        if train_iter % cfg.save_model_iter == 0:
            val_bleu.append(val_eval(encoder, decoder, dataloader))
            torch.save(
                encoder.state_dict(), './models/train/encoder_' +
                cfg.pre_train_model + '_' + str(train_iter) + '.pkl')
            torch.save(decoder.state_dict(),
                       './models/train/decoder_' + str(train_iter) + '.pkl')
            np.save('./result/train_bleu4.npy', val_bleu)
            np.save('./result/losses.npy', losses)

        if train_iter == cfg.train_iter:
            break
        train_iter += 1
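
cal_loss is defined elsewhere in this project. A plausible sketch, assuming the usual cross-entropy plus the doubly stochastic attention penalty from "Show, Attend and Tell" (only the signature is read off the call above; the shapes and body are assumptions):

import torch.nn as nn

def cal_loss(predictions, targets, alphas, alpha_c):
    # predictions: (B, T, vocab); targets: (B, T); alphas: (B, T, num_pixels)
    ce = nn.CrossEntropyLoss()(predictions.reshape(-1, predictions.size(-1)),
                               targets.reshape(-1))
    # encourage each spatial location to receive ~1 total attention over time
    att_reg = alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
    return ce + att_reg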
Example #4
def train():
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    vocab_size = len(vocab)
    print('vocab_size:', vocab_size)

    dataloader = get_loader(image_dir,
                            caption_path,
                            vocab,
                            batch_size,
                            crop_size,
                            shuffle=True,
                            num_workers=num_workers)

    encoder = Encoder(embedding_size).to(device)
    decoder = Decoder(vocab_size, embedding_size, lstm_size).to(device)
    if os.path.exists(encoder_path):
        encoder.load_state_dict(torch.load(encoder_path))
    if os.path.exists(decoder_path):
        decoder.load_state_dict(torch.load(decoder_path))

    loss_fn = torch.nn.CrossEntropyLoss()
    parameters = list(encoder.fc.parameters()) + list(
        encoder.bn.parameters()) + list(decoder.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 lr=learning_rate,
                                 betas=(0.9, 0.99))

    num_steps = len(dataloader)
    for epoch in range(num_epochs):
        for index, (imgs, captions, lengths) in enumerate(dataloader):
            imgs = imgs.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(
                captions, lengths,
                batch_first=True)[0]  # the trailing [0] is necessary

            features = encoder(imgs)
            y_predicted = decoder(features, captions, lengths)
            loss = loss_fn(y_predicted, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if index % log_every == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, num_epochs, index, num_steps, loss.item(),
                            np.exp(loss.item())))

            if index % save_every == 0 and index != 0:
                print('Start saving encoder')
                torch.save(encoder.state_dict(), encoder_path)
                print('Start saving decoder')
                torch.save(decoder.state_dict(), decoder_path)
Example #5
class MyModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        custom_config = model_config["custom_options"]
        latent_size = custom_config['latent_size']

        self.main = Encoder(latent_size=latent_size)

        if custom_config['encoder_path'] is not None:
            # saved checkpoints could contain extra weights such as linear_logsigma
            weights = torch.load(custom_config['encoder_path'],
                                 map_location=torch.device('cpu'))
            for k in list(weights.keys()):
                if k not in self.main.state_dict().keys():
                    del weights[k]
            self.main.load_state_dict(weights)
            print("Loaded Weights")
        else:
            print("No Load Weights")

        self.critic = nn.Sequential(nn.Linear(latent_size, 400), nn.ReLU(),
                                    nn.Linear(400, 300), nn.ReLU(),
                                    nn.Linear(300, 1))
        self.actor = nn.Sequential(nn.Linear(latent_size, 400), nn.ReLU(),
                                   nn.Linear(400, 300), nn.ReLU())
        self.alpha_head = nn.Sequential(nn.Linear(300, 3), nn.Softplus())
        self.beta_head = nn.Sequential(nn.Linear(300, 3), nn.Softplus())
        self._cur_value = None
        self.train_encoder = custom_config['train_encoder']
        print("Train Encoder: ", self.train_encoder)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        features = self.main(input_dict['obs'].float())
        if not self.train_encoder:
            features = features.detach()  # do not train the encoder

        actor_features = self.actor(features)
        alpha = self.alpha_head(actor_features) + 1
        beta = self.beta_head(actor_features) + 1
        logits = torch.cat([alpha, beta], dim=1)
        self._cur_value = self.critic(features).squeeze(1)

        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, 'Must call forward() first'
        return self._cur_value
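
Downstream, RLlib would typically hand these logits to a custom action distribution; a sketch of how a Beta distribution could consume them (illustrative only, not part of the original model):

import torch

alpha, beta = torch.chunk(logits, 2, dim=1)  # each Bx3, every entry > 1
dist = torch.distributions.Beta(alpha, beta)
action = dist.sample()  # every component lies in (0, 1)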
Example #6
class MyModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        custom_config = model_config['custom_options']

        self.main = Encoder()

        if custom_config['encoder_path'] is not None:
            print("Load Trained Encoder")
            # saved checkpoints could contain extra weights such as linear_logsigma
            weights = torch.load(custom_config['encoder_path'],
                                 map_location={'cuda:0': 'cpu'})
            for k in list(weights.keys()):
                if k not in self.main.state_dict().keys():
                    del weights[k]
            self.main.load_state_dict(weights)

        self.critic = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(),
                                    nn.Linear(1024, 256), nn.ReLU(),
                                    nn.Linear(256, 1))
        self.actor = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(),
                                   nn.Linear(1024, 256), nn.ReLU(),
                                   nn.Linear(256, 3), nn.Sigmoid())
        self.actor_logstd = nn.Parameter(torch.zeros(3), requires_grad=True)
        self._cur_value = None
        print("Train Encoder:", custom_config['train_encoder'])
        self.train_encoder = custom_config['train_encoder']

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        features = self.main(input_dict['obs'].float())
        if not self.train_encoder:
            features = features.detach()  # do not train the encoder

        actor_mu = self.actor(features)  # Bx3
        batch_size = actor_mu.shape[0]
        actor_logstd = torch.stack(batch_size * [self.actor_logstd],
                                   dim=0)  # Bx3
        logits = torch.cat([actor_mu, actor_logstd], dim=1)
        self._cur_value = self.critic(features).squeeze(1)

        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, 'Must call forward() first'
        return self._cur_value
Example #7
                loss_z_val.update(loss_z_cur.item())
                loss_recon_val.update(loss_recon_cur.item())
                loss_classify_val.update(loss_classify_cur.item())
                loss_val.update(loss_cur.item())
                batch_time = time.time() - batch_start_time
                bar(batch_idx,
                    len(val_loader),
                    "Epoch: {:3d} | ".format(epoch),
                    ' | time {:.3f} | loss_val {:.5f} | loss_z_val {:.5f} | loss_recon_val {:.5f} | loss_classify_val {:.5f}  |'
                    .format(batch_time, loss_val.val, loss_z_val.val,
                            loss_recon_val.val, loss_classify_val.val),
                    end_string="")

            log_entry_val = '\n| end of epoch {:3d} | time: {:5.5f}s | valid loss {:.5f} | valid recon loss {:.5f} | valid classify loss {:.5f} | valid psnr {:5.2f}'.format(
                epoch, (time.time() - epoch_start_time), loss_val.avg,
                loss_recon_val.avg, loss_classify_val.avg, psnr_val.avg)
            print(log_entry_val)
            with open(os.path.join(args.save, 'val.log'), 'a') as f:
                f.write(log_entry_val)

        if epoch % args.save_every == 0:
            states = {
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'classifier': classifier.state_dict()
            }
            torch.save(
                states,
                os.path.join(args.save, 'checkpoint_' + str(epoch) + '.pth'))
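
A matching restore for the checkpoint dictionary saved above might look like this (a sketch; the epoch number and map_location are assumptions):

states = torch.load(os.path.join(args.save, 'checkpoint_10.pth'),
                    map_location='cpu')
encoder.load_state_dict(states['encoder'])
decoder.load_state_dict(states['decoder'])
classifier.load_state_dict(states['classifier'])
start_epoch = states['epoch'] + 1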
Example #8
def main():
    epoch = 1000
    batch_size = 64
    hidden_dim = 300
    use_cuda = True

    encoder = Encoder(num_words, hidden_dim)
    if args.attn:
        attn_model = 'dot'
        decoder = LuongAttnDecoderRNN(attn_model, hidden_dim, num_words)
    else:
        decoder = DecoderRhyme(hidden_dim, num_words, num_target_lengths,
                               num_rhymes)

    if args.train:
        weight = torch.ones(num_words)
        weight[word2idx_mapping[PAD_TOKEN]] = 0
        if use_cuda:
            encoder = encoder.cuda()
            decoder = decoder.cuda()
            weight = weight.cuda()
        encoder_optimizer = Adam(encoder.parameters(), lr=0.001)
        decoder_optimizer = Adam(decoder.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss(weight=weight)

        np.random.seed(1124)
        order = np.arange(len(train_data))

        best_loss = 1e10
        best_epoch = 0

        for e in range(epoch):
            #if e - best_epoch > 20: break

            np.random.shuffle(order)
            shuffled_train_data = train_data[order]
            shuffled_x_lengths = input_lengths[order]
            shuffled_y_lengths = target_lengths[order]
            shuffled_y_rhyme = target_rhymes[order]
            train_loss = 0
            valid_loss = 0  # stays 0 here: the validation block below is commented out
            for b in tqdm(range(int(len(order) // batch_size))):
                #print(b, '\r', end='')
                batch_x = torch.LongTensor(
                    shuffled_train_data[b * batch_size:(b + 1) *
                                        batch_size][:, 0].tolist()).t()
                batch_y = torch.LongTensor(
                    shuffled_train_data[b * batch_size:(b + 1) *
                                        batch_size][:, 1].tolist()).t()
                batch_x_lengths = shuffled_x_lengths[b * batch_size:(b + 1) *
                                                     batch_size]
                batch_y_lengths = shuffled_y_lengths[b * batch_size:(b + 1) *
                                                     batch_size]
                batch_y_rhyme = shuffled_y_rhyme[b * batch_size:(b + 1) *
                                                 batch_size]

                if use_cuda:
                    batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

                train_loss += train(batch_x, batch_y, batch_y_lengths,
                                    max(batch_y_lengths), batch_y_rhyme,
                                    encoder, decoder, encoder_optimizer,
                                    decoder_optimizer, criterion, use_cuda,
                                    False)

            train_loss /= (b + 1)  # b ends at the last batch index, so b+1 batches
            '''
            for b in range(len(valid_data) // batch_size):
                batch_x = torch.LongTensor(valid_data[b*batch_size: (b+1)*batch_size][:, 0].tolist()).t()
                batch_y = torch.LongTensor(valid_data[b*batch_size: (b+1)*batch_size][:, 1].tolist()).t()
                if use_cuda:
                    batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

                valid_loss += train(batch_x, batch_y, max_seqlen, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, use_cuda, True)
            valid_loss /= b
            '''
            print(
                "epoch {}, train_loss {:.4f}, valid_loss {:.4f}, best_epoch {}, best_loss {:.4f}"
                .format(e, train_loss, valid_loss, best_epoch, best_loss))
            '''
            if valid_loss < best_loss:
                best_loss = valid_loss
                best_epoch = e
                torch.save(encoder.state_dict(), args.encoder_path + '.best')
                torch.save(decoder.state_dict(), args.decoder_path + '.best')
            '''
            torch.save(encoder.state_dict(), args.encoder_path)
            torch.save(decoder.state_dict(), args.decoder_path)
        print(encoder)
        print(decoder)
        print("==============")

    else:
        encoder.load_state_dict(torch.load(
            args.encoder_path))  #, map_location=torch.device('cpu')))
        decoder.load_state_dict(torch.load(
            args.decoder_path))  #, map_location=torch.device('cpu')))
        print(encoder)
        print(decoder)

    predict(encoder, decoder)
Example #9
            loss_imgs = criterion(fake, real)
            loss_features = criterion(fake_features, real_features)
            enc_loss = loss_imgs + kappa * loss_features

            enc_loss.backward()
            opt_enc.step()
            # if i % CRITIC_ITERATIONS == 0:
            # e_losses.append(e_loss)
        # enc.eval()
        writer_enc.add_scalar('enc_loss', enc_loss.item(), epoch)
        print(f"[Epoch {epoch:{padding_epoch}}/{enc_epochs}] "
              f"[Batch {batch_idx:{padding_i}}/{len(loader)}] "
              f"[E loss: {enc_loss.item():3f}]")
        # step += 1
    torch.save(
        enc.state_dict(),
        '/home/mihael/ML/GANs_n_Anomalies/torch_GAN/f-AnoGAN_w_WGAN-GP/models/netE_%d.pth'
        % enc_epochs)

evaluate = True
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
)
if evaluate:
Example #10
if not device == "cpu":
    net_g = nn.DataParallel(net_g)
    net_h = nn.DataParallel(net_h)
    net_DCD = nn.DataParallel(net_DCD)

for epoch in range(num_ep_init_gh):
    for data, label in s_trainloader:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        pred = net_h(net_g(data))
        loss = loss_func(pred, label)
        loss.backward()
        optimizer.step()
    if epoch % 10 == 0:
        print("epoch{} has finished".format(epoch))
torch.save(net_g.state_dict(), "model_g")
torch.save(net_h.state_dict(), "model_h")
with torch.no_grad():
    acc = 0
    total = 0
    for te_data, te_label in s_testloader:
        te_data, te_label = te_data.to(device), te_label.to(device)
        output = net_h(net_g(te_data))
        pred = torch.argmax(output, dim=1)
        acc += (pred == te_label).sum().item() / len(te_label)
    acc = acc / len(s_testloader)
    print("accuracy in initial train of g and h(source):{}".format(acc))
    acc = 0
    total = 0
    for te_data, te_label in t_testloader:
        te_label = te_label.type(torch.LongTensor)
Example #11
def train():
    # 1. Prepare the dataset
    data = json.load(open(Config.train_data_path, 'r'))

    input_data = data['input_data']
    input_len = data['input_len']
    output_data = data['output_data']
    mask_data = data['mask']
    output_len = data['output_len']

    total_len = len(input_data)
    step = total_len // Config.batch_size

    # Word-embedding layer
    embedding = nn.Embedding(Config.vocab_size,
                             Config.hidden_size,
                             padding_idx=Config.PAD)

    # 2. Build the models
    encoder = Encoder(embedding)
    attn_model = 'dot'
    decoder = Decoder(
        attn_model,
        embedding,
    )

    encoder_optimizer = torch.optim.Adam(encoder.parameters(),
                                         lr=Config.learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(),
                                         lr=Config.learning_rate)

    for epoch in range(Config.num_epochs):
        for i in range(step - 1):
            start_time = time.time()
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            input_ids = torch.LongTensor(
                input_data[i * Config.batch_size:(i + 1) *
                           Config.batch_size]).to(Config.device)
            inp_len = torch.LongTensor(
                input_len[i * Config.batch_size:(i + 1) *
                          Config.batch_size]).to(Config.device)
            output_ids = torch.LongTensor(
                output_data[i * Config.batch_size:(i + 1) *
                            Config.batch_size]).to(Config.device)
            mask = torch.BoolTensor(mask_data[i * Config.batch_size:(i + 1) *
                                              Config.batch_size]).to(
                                                  Config.device)
            out_len = output_len[i * Config.batch_size:(i + 1) *
                                 Config.batch_size]

            max_ans_len = max(out_len)

            mask = mask.permute(1, 0)
            output_ids = output_ids.permute(1, 0)
            encoder_outputs, hidden = encoder(input_ids, inp_len)
            encoder_outputs = encoder_outputs.permute(1, 0, 2)
            decoder_hidden = hidden.unsqueeze(0)

            # Create the initial decoder input (an SOS token for each sequence in the batch)
            decoder_input = torch.LongTensor(
                [[Config.SOS for _ in range(Config.batch_size)]])
            decoder_input = decoder_input.to(Config.device)

            # Determine if we are using teacher forcing this iteration
            teacher_forcing_ratio = 0.3
            use_teacher_forcing = random.random() < teacher_forcing_ratio

            loss = 0
            print_losses = []
            n_totals = 0
            if use_teacher_forcing:
                # Teacher forcing: feed the ground-truth token from the previous step as input
                for t in range(max_ans_len):
                    decoder_output, decoder_hidden = decoder(
                        decoder_input, decoder_hidden, encoder_outputs)
                    # print(decoder_output.size())  # torch.Size([2, 2672])
                    # print(decoder_hidden.size())   # torch.Size([1, 2, 512])

                    decoder_input = output_ids[t].view(1, -1)
                    # 计算损失
                    mask_loss, nTotal = maskNLLLoss(decoder_output,
                                                    output_ids[t], mask[t])
                    # print('1', mask_loss)
                    loss += mask_loss
                    print_losses.append(mask_loss.item() * nTotal)
                    n_totals += nTotal
            else:
                # Without teacher forcing: feed the decoder's own previous prediction as input
                for t in range(max_ans_len):
                    decoder_output, decoder_hidden = decoder(
                        decoder_input, decoder_hidden, encoder_outputs)

                    _, topi = decoder_output.topk(1)
                    decoder_input = torch.LongTensor(
                        [[topi[i][0] for i in range(Config.batch_size)]])
                    decoder_input = decoder_input.to(Config.device)
                    # Calculate and accumulate loss
                    mask_loss, nTotal = maskNLLLoss(decoder_output,
                                                    output_ids[t], mask[t])
                    # print('2', mask_loss)
                    loss += mask_loss
                    print_losses.append(mask_loss.item() * nTotal)
                    n_totals += nTotal

            # Perform backpropagation
            loss.backward()

            # Gradient clipping
            _ = nn.utils.clip_grad_norm_(encoder.parameters(), Config.clip)
            _ = nn.utils.clip_grad_norm_(decoder.parameters(), Config.clip)

            # Adjust model weights
            encoder_optimizer.step()
            decoder_optimizer.step()
            avg_loss = sum(print_losses) / n_totals

            time_str = datetime.datetime.now().isoformat()
            log_str = 'time:{}, epoch:{}, step:{}, loss:{:5f}, spend_time:{:6f}'.format(
                time_str, epoch, i, avg_loss,
                time.time() - start_time)
            rainbow(log_str)

        if epoch % 1 == 0:
            save_path = './save_model/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            torch.save(
                {
                    'epoch': epoch,
                    'encoder': encoder.state_dict(),
                    'decoder': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': avg_loss,
                    'embedding': embedding.state_dict()
                },
                os.path.join(
                    save_path,
                    'epoch{}_{}_model.tar'.format(epoch, 'checkpoint')))
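
maskNLLLoss is not shown here; a common definition, matching the PyTorch chatbot tutorial that this loop closely follows (a sketch, assuming decoder_output holds softmax probabilities):

import torch

def maskNLLLoss(inp, target, mask):
    # inp: (B, vocab) probabilities; target: (B,); mask: (B,) bool
    nTotal = mask.sum()
    crossEntropy = -torch.log(
        torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    return loss, nTotal.item()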
Example #12
def train(args, logger):
    task_time = time.strftime("%Y-%m-%d %H:%M", time.localtime())
    Path("./saved_models/").mkdir(parents=True, exist_ok=True)
    Path("./pretrained_models/").mkdir(parents=True, exist_ok=True)
    MODEL_SAVE_PATH = './saved_models/'
    Pretrained_MODEL_PATH = './pretrained_models/'
    get_model_name = lambda part: f'{part}-{args.data}-{args.tasks}-{args.prefix}.pth'
    get_pretrain_model_name = lambda part: f'{part}-{args.data}-LP-{args.prefix}.pth'
    device_string = 'cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu >=0 else 'cpu'
    print('Model training with ' + device_string)
    device = torch.device(device_string)
    


    g = load_graphs(f"./data/{args.data}.dgl")[0][0]
    
    efeat_dim = g.edata['feat'].shape[1]
    nfeat_dim = efeat_dim


    train_loader, val_loader, test_loader, num_val_samples, num_test_samples = dataloader(args, g)


    encoder = Encoder(args, nfeat_dim, n_head=args.n_head, dropout=args.dropout).to(device)
    decoder = Decoder(args, nfeat_dim).to(device)
    msg2mail = Msg2Mail(args, nfeat_dim)
    fraud_sampler = frauder_sampler(g)

    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    scheduler_lr = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40)
    if args.warmup:
        scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=3, after_scheduler=scheduler_lr)
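        # dummy zero_grad/step so the warmup scheduler's first .step() does not
        # trigger PyTorch's "scheduler.step() before optimizer.step()" warning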
        optimizer.zero_grad()
        optimizer.step()
    loss_fcn = torch.nn.BCEWithLogitsLoss()

    loss_fcn = loss_fcn.to(device)

    early_stopper = EarlyStopMonitor(logger=logger, max_round=args.patience, higher_better=True)

    if args.pretrain:
        logger.info(f'Loading the linkpred pretrained attention based encoder model')
        encoder.load_state_dict(torch.load(Pretrained_MODEL_PATH+get_pretrain_model_name('Encoder')))

    for epoch in range(args.n_epoch):
        # reset node state
        g.ndata['mail'] = torch.zeros((g.num_nodes(), args.n_mail, nfeat_dim+2), dtype=torch.float32) 
        g.ndata['feat'] = torch.zeros((g.num_nodes(), nfeat_dim), dtype=torch.float32) # initialized to zero; other initializations are possible
        g.ndata['last_update'] = torch.zeros((g.num_nodes()), dtype=torch.float32) 
        encoder.train()
        decoder.train()
        start_epoch = time.time()
        m_loss = []
        logger.info('start {} epoch, current optim lr is {}'.format(epoch, optimizer.param_groups[0]['lr']))
        for batch_idx, (input_nodes, pos_graph, neg_graph, blocks, frontier, current_ts) in enumerate(train_loader):
            

            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device) if neg_graph is not None else None
            

            if not args.no_time or not args.no_pos:
                current_ts, pos_ts, num_pos_nodes = get_current_ts(args, pos_graph, neg_graph)
                pos_graph.ndata['ts'] = current_ts
            else:
                current_ts, pos_ts, num_pos_nodes = None, None, None
            
            _ = dgl.add_reverse_edges(neg_graph) if neg_graph is not None else None
            emb, _ = encoder(dgl.add_reverse_edges(pos_graph), _, num_pos_nodes)
            if batch_idx != 0:
                if 'LP' not in args.tasks and args.balance:
                    neg_graph = fraud_sampler.sample_fraud_event(g, args.bs//5, current_ts.max().cpu()).to(device)
                logits, labels = decoder(emb, pos_graph, neg_graph)

                loss = loss_fcn(logits, labels)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                m_loss.append(loss.item())


            # MSG Passing
            with torch.no_grad():
                mail = msg2mail.gen_mail(args, emb, input_nodes, pos_graph, frontier, 'train')

                if not args.no_time:
                    g.ndata['last_update'][pos_graph.ndata[dgl.NID][:num_pos_nodes]] = pos_ts.to('cpu')
                g.ndata['feat'][pos_graph.ndata[dgl.NID]] = emb.to('cpu')
                g.ndata['mail'][input_nodes] = mail
            if batch_idx % 100 == 1:
                gpu_mem = torch.cuda.max_memory_allocated() / 1.074e9 if torch.cuda.is_available() and args.gpu >= 0 else 0
                torch.cuda.empty_cache()
                mem_perc = psutil.virtual_memory().percent
                cpu_perc = psutil.cpu_percent(interval=None)
                output_string = f'Epoch {epoch} | Step {batch_idx}/{len(train_loader)} | CPU {cpu_perc:.1f}% | Sys Mem {mem_perc:.1f}% | GPU Mem {gpu_mem:.4f}GB '
                
                output_string += f'| {args.tasks} Loss {np.mean(m_loss):.4f}'

                logger.info(output_string)

        total_epoch_time = time.time() - start_epoch
        logger.info(' training epoch: {} took {:.4f}s'.format(epoch, total_epoch_time))
        val_ap, val_auc, val_acc, val_loss = eval_epoch(args, logger, g, val_loader, encoder, decoder, msg2mail, loss_fcn, device, num_val_samples)
        logger.info('Val {} Task | ap: {:.4f} | auc: {:.4f} | acc: {:.4f} | Loss: {:.4f}'.format(args.tasks, val_ap, val_auc, val_acc, val_loss))

        if args.warmup:
            scheduler_warmup.step(epoch)
        else:
            scheduler_lr.step()

        early_stopper_metric = val_ap if 'LP' in args.tasks else val_auc

        if early_stopper.early_stop_check(early_stopper_metric):
            logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
            logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
            encoder.load_state_dict(torch.load(MODEL_SAVE_PATH+get_model_name('Encoder')))
            decoder.load_state_dict(torch.load(MODEL_SAVE_PATH+get_model_name('Decoder')))

            test_result = [early_stopper.best_ap, early_stopper.best_auc, early_stopper.best_acc, early_stopper.best_loss]
            break

        test_ap, test_auc, test_acc, test_loss = eval_epoch(args, logger, g, test_loader, encoder, decoder, msg2mail, loss_fcn, device, num_test_samples)
        logger.info('Test {} Task | ap: {:.4f} | auc: {:.4f} | acc: {:.4f} | Loss: {:.4f}'.format(args.tasks, test_ap, test_auc, test_acc, test_loss))
        test_result = [test_ap, test_auc, test_acc, test_loss]

        if early_stopper.best_epoch == epoch: 
            early_stopper.best_ap = test_ap
            early_stopper.best_auc = test_auc
            early_stopper.best_acc = test_acc
            early_stopper.best_loss = test_loss
            logger.info(f'Saving the best model at epoch {early_stopper.best_epoch}')
            torch.save(encoder.state_dict(), MODEL_SAVE_PATH+get_model_name('Encoder'))
            torch.save(decoder.state_dict(), MODEL_SAVE_PATH+get_model_name('Decoder'))
Example #13
File: train.py  Project: entn-at/G2P-2
        # decoder
        T, N = p.size()
        outputs = []
        hidden = torch.ones(1, N,
                            ModelConfig.hidden_size).to(TrainConfig.device)
        for t in range(T - 1):
            out, hidden, _ = decoder_model(p[t:t + 1], enc, hidden)
            outputs.append(out)
        outputs = torch.cat(outputs)

        # flat Time and Batch, calculate loss
        outputs = outputs.view((T - 1) * N, -1)
        p = p[1:]  # trim first phoneme
        p = p.view(-1)
        loss = criterion(outputs, p)

        # update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        log.add_scalar('loss', loss.item(), counter)
        counter += 1

    # save model
    torch.save(encoder_model.state_dict(),
               f'models/{DataConfig.language}/encoder_e{e+1:02d}.pth')
    torch.save(decoder_model.state_dict(),
               f'models/{DataConfig.language}/decoder_e{e+1:02d}.pth')
Example #14
def main(args):

    #create a writer
    writer = SummaryWriter('loss_plot_' + args.mode, comment='test')
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    val_length = len(os.listdir(args.image_dir_val))

    # Build data loader
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    data_loader_val = get_loader(args.image_dir_val,
                                 args.caption_path_val,
                                 vocab,
                                 transform,
                                 args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers)

    # Build the model
    # if no-attention model is chosen:
    if args.model_type == 'no_attention':
        encoder = Encoder(args.embed_size).to(device)
        decoder = Decoder(args.embed_size, args.hidden_size, len(vocab),
                          args.num_layers).to(device)
        criterion = nn.CrossEntropyLoss()

    # if attention model is chosen:
    elif args.model_type == 'attention':
        encoder = EncoderAtt(encoded_image_size=9).to(device)
        decoder = DecoderAtt(vocab, args.encoder_dim, args.hidden_size,
                             args.attention_dim, args.embed_size,
                             args.dropout_ratio, args.alpha_c).to(device)

    # if transformer model is chosen:
    elif args.model_type == 'transformer':
        model = Transformer(len(vocab), args.embed_size,
                            args.transformer_layers, 8,
                            args.dropout_ratio).to(device)

        encoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, model.encoder.parameters()),
                                             lr=args.learning_rate_enc)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, model.decoder.parameters()),
                                             lr=args.learning_rate_dec)
        criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<pad>'])

    else:
        print('Select model_type: attention, no_attention or transformer')

    # if model is not transformer: additional step in encoder is needed: freeze lower layers of resnet if args.fine_tune == True
    if args.model_type != 'transformer':
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=args.learning_rate_dec)
        encoder.fine_tune(args.fine_tune)
        encoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, encoder.parameters()),
                                             lr=args.learning_rate_enc)

    # initialize lists to store results:
    loss_train = []
    loss_val = []
    loss_val_epoch = []
    loss_train_epoch = []

    bleu_res_list = []
    cider_res_list = []
    rouge_res_list = []

    results = {}

    # calculate total steps for train and validation
    total_step = len(data_loader)
    total_step_val = len(data_loader_val)

    #For each epoch
    for epoch in tqdm(range(args.num_epochs)):

        loss_val_iter = []
        loss_train_iter = []

        # set model to train mode
        if args.model_type != 'transformer':
            encoder.train()
            decoder.train()
        else:
            model.train()

        # for each entry in data_loader
        for i, (images, captions, lengths) in tqdm(enumerate(data_loader)):
            # load images and captions to device
            images = images.to(device)
            captions = captions.to(device)
            # Forward, backward and optimize

            # forward and backward path is different dependent of model type:
            if args.model_type == 'no_attention':
                # get features from encoder
                features = encoder(images)
                # pack the padded targets into one flat tensor
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]
                # get output from decoder
                outputs = decoder(features, captions, lengths)
                # calculate loss
                loss = criterion(outputs, targets)

                # optimizer and backward step
                decoder_optimizer.zero_grad()
                encoder_optimizer.zero_grad()
                loss.backward()
                decoder_optimizer.step()
                encoder_optimizer.step()

            elif args.model_type == 'attention':

                # get features from encoder
                features = encoder(images)

                # get targets - starting from the 2nd word in captions
                # (the model is not sequential, so targets are predicted in
                # parallel; no need to predict the first word in captions)

                targets = captions[:, 1:]
                # decode length = length-1 for each caption
                decode_lengths = [length - 1 for length in lengths]
                #flatten targets
                targets = targets.reshape(targets.shape[0] * targets.shape[1])

                sampled_caption = []

                # get scores and alphas from decoder
                scores, alphas = decoder(features, captions, decode_lengths)

                scores = scores.view(-1, scores.shape[-1])

                #predicted = prediction with maximum score
                _, predicted = torch.max(scores, dim=1)

                # calculate loss
                loss = decoder.loss(scores, targets, alphas)

                # optimizer and backward step
                decoder_optimizer.zero_grad()
                encoder_optimizer.zero_grad()
                loss.backward()
                decoder_optimizer.step()
                encoder_optimizer.step()

            elif args.model_type == 'transformer':

                # input is captions without last word
                trg_input = captions[:, :-1]
                # create mask
                trg_mask = create_masks(trg_input)

                # get scores from model
                scores = model(images, trg_input, trg_mask)
                scores = scores.view(-1, scores.shape[-1])

                # get targets - starting from the 2nd word in captions
                targets = captions[:, 1:]

                #predicted = prediction with maximum score
                _, predicted = torch.max(scores, dim=1)

                # calculate loss
                loss = criterion(
                    scores,
                    targets.reshape(targets.shape[0] * targets.shape[1]))

                #forward and backward path
                decoder_optimizer.zero_grad()
                encoder_optimizer.zero_grad()
                loss.backward()
                decoder_optimizer.step()
                encoder_optimizer.step()

            else:
                print('Select model_type: attention, no_attention or transformer')

            # append results to loss lists and writer
            loss_train_iter.append(loss.item())
            loss_train.append(loss.item())
            writer.add_scalar('Loss/train/iterations', loss.item(), i + 1)

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step, loss.item(),
                            np.exp(loss.item())))

        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'.
              format(epoch, args.num_epochs, i, total_step, loss.item(),
                     np.exp(loss.item())))

        #append mean of last 10 batches as approximate epoch loss
        loss_train_epoch.append(np.mean(loss_train_iter[-10:]))

        writer.add_scalar('Loss/train/epoch', np.mean(loss_train_iter[-10:]),
                          epoch + 1)

        #save model
        if args.model_type != 'transformer':
            torch.save(
                decoder.state_dict(),
                os.path.join(
                    args.model_path,
                    'decoder_' + args.mode + '_{}.ckpt'.format(epoch + 1)))
            torch.save(
                encoder.state_dict(),
                os.path.join(
                    args.model_path,
                    'encoder_' + args.mode + '_{}.ckpt'.format(epoch + 1)))

        else:
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.model_path,
                    'model_' + args.mode + '_{}.ckpt'.format(epoch + 1)))
        np.save(
            os.path.join(args.predict_json,
                         'loss_train_temp_' + args.mode + '.npy'), loss_train)

        #validate model:
        # set model to eval mode:
        if args.model_type != 'transformer':
            encoder.eval()
            decoder.eval()
        else:
            model.eval()
        total_step = len(data_loader_val)

        # set no_grad mode:
        with torch.no_grad():
            # for each entry in data_loader
            for i, (images, captions,
                    lengths) in tqdm(enumerate(data_loader_val)):
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]
                images = images.to(device)
                captions = captions.to(device)

                # forward and backward path is different dependent of model type:
                if args.model_type == 'no_attention':
                    features = encoder(images)
                    outputs = decoder(features, captions, lengths)
                    loss = criterion(outputs, targets)

                elif args.model_type == 'attention':

                    features = encoder(images)
                    sampled_caption = []
                    targets = captions[:, 1:]
                    decode_lengths = [length - 1 for length in lengths]
                    targets = targets.reshape(targets.shape[0] *
                                              targets.shape[1])

                    scores, alphas = decoder(features, captions,
                                             decode_lengths)

                    _, predicted = torch.max(scores, dim=1)

                    scores = scores.view(-1, scores.shape[-1])

                    sampled_caption = []

                    loss = decoder.loss(scores, targets, alphas)

                elif args.model_type == 'transformer':

                    trg_input = captions[:, :-1]
                    trg_mask = create_masks(trg_input)
                    scores = model(images, trg_input, trg_mask)
                    scores = scores.view(-1, scores.shape[-1])
                    targets = captions[:, 1:]

                    _, predicted = torch.max(scores, dim=1)

                    loss = criterion(
                        scores,
                        targets.reshape(targets.shape[0] * targets.shape[1]))

                #display results
                if i % args.log_step == 0:
                    print(
                        'Epoch [{}/{}], Step [{}/{}], Validation Loss: {:.4f}, Validation Perplexity: {:5.4f}'
                        .format(epoch, args.num_epochs, i, total_step_val,
                                loss.item(), np.exp(loss.item())))

                # append results to loss lists and writer
                loss_val.append(loss.item())
                loss_val_iter.append(loss.item())

                writer.add_scalar('Loss/validation/iterations', loss.item(),
                                  i + 1)

        np.save(
            os.path.join(args.predict_json, 'loss_val_' + args.mode + '.npy'),
            loss_val)

        print(
            'Epoch [{}/{}], Step [{}/{}], Validation Loss: {:.4f}, Validation Perplexity: {:5.4f}'
            .format(epoch, args.num_epochs, i, total_step_val, loss.item(),
                    np.exp(loss.item())))

        # results: epoch validation loss

        loss_val_epoch.append(np.mean(loss_val_iter))
        # log this epoch's mean validation loss (not the mean over all epochs)
        writer.add_scalar('Loss/validation/epoch', loss_val_epoch[-1],
                          epoch + 1)

        #predict captions:
        filenames = os.listdir(args.image_dir_val)

        predicted = {}

        for file in tqdm(filenames):
            if file == '.DS_Store':
                continue
            # Prepare an image
            image = load_image(os.path.join(args.image_dir_val, file),
                               transform)
            image_tensor = image.to(device)

            # Generate caption starting with <start> word

            # procedure is different for each model type
            if args.model_type == 'attention':

                features = encoder(image_tensor)
                sampled_ids, _ = decoder.sample(features)
                sampled_ids = sampled_ids[0].cpu().numpy()
                #start sampled_caption with <start>
                sampled_caption = ['<start>']

            elif args.model_type == 'no_attention':
                features = encoder(image_tensor)
                sampled_ids = decoder.sample(features)
                sampled_ids = sampled_ids[0].cpu().numpy()
                sampled_caption = ['<start>']

            elif args.model_type == 'transformer':

                e_outputs = model.encoder(image_tensor)
                max_seq_length = 20
                sampled_ids = torch.zeros(max_seq_length, dtype=torch.long)
                sampled_ids[0] = torch.LongTensor([[vocab.word2idx['<start>']]
                                                   ]).to(device)

                for i in range(1, max_seq_length):

                    trg_mask = np.triu(np.ones((1, i, i)), k=1).astype('uint8')
                    trg_mask = Variable(
                        torch.from_numpy(trg_mask) == 0).to(device)

                    out = model.decoder(sampled_ids[:i].unsqueeze(0),
                                        e_outputs, trg_mask)

                    out = model.out(out)
                    out = F.softmax(out, dim=-1)
                    val, ix = out[:, -1].data.topk(1)
                    sampled_ids[i] = ix[0][0]

                sampled_ids = sampled_ids.cpu().numpy()
                sampled_caption = []

            # Convert word_ids to words
            for word_id in sampled_ids:
                word = vocab.idx2word[word_id]
                sampled_caption.append(word)
                # break at <end> of the sentence
                if word == '<end>':
                    break
            sentence = ' '.join(sampled_caption)

            predicted[file] = sentence

        # save predictions to json file:
        json.dump(
            predicted,
            open(
                os.path.join(
                    args.predict_json,
                    'predicted_' + args.mode + '_' + str(epoch) + '.json'),
                'w'))

        #validate model
        with open(args.caption_path_val, 'r') as file:
            captions = json.load(file)

        res = {}
        for r in predicted:
            # str.strip removes a set of characters, not a substring,
            # so use replace to drop the <start>/<end> tokens
            res[r] = [predicted[r].replace('<start> ', '').replace(' <end>', '')]

        images = captions['images']
        caps = captions['annotations']
        gts = {}
        for image in images:
            image_id = image['id']
            file_name = image['file_name']
            list_cap = []
            for cap in caps:
                if cap['image_id'] == image_id:
                    list_cap.append(cap['caption'])
            gts[file_name] = list_cap

        # calculate BLEU, CIDEr and ROUGE metrics from reference and generated captions
        bleu_res = bleu(gts, res)
        cider_res = cider(gts, res)
        rouge_res = rouge(gts, res)

        # append results to result lists
        bleu_res_list.append(bleu_res)
        cider_res_list.append(cider_res)
        rouge_res_list.append(rouge_res)

        # write results to writer
        writer.add_scalar('BLEU1/validation/epoch', bleu_res[0], epoch + 1)
        writer.add_scalar('BLEU2/validation/epoch', bleu_res[1], epoch + 1)
        writer.add_scalar('BLEU3/validation/epoch', bleu_res[2], epoch + 1)
        writer.add_scalar('BLEU4/validation/epoch', bleu_res[3], epoch + 1)
        writer.add_scalar('CIDEr/validation/epoch', cider_res, epoch + 1)
        writer.add_scalar('ROUGE/validation/epoch', rouge_res, epoch + 1)

    results['bleu'] = bleu_res_list
    results['cider'] = cider_res_list
    results['rouge'] = rouge_res_list

    json.dump(
        results,
        open(os.path.join(args.predict_json, 'results_' + args.mode + '.json'),
             'w'))
    np.save(
        os.path.join(args.predict_json, 'loss_train_' + args.mode + '.npy'),
        loss_train)
    np.save(os.path.join(args.predict_json, 'loss_val_' + args.mode + '.npy'),
            loss_val)
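
create_masks is defined elsewhere; a plausible sketch, assuming a standard pad-plus-no-peek target mask (the pad index and the returned shape are assumptions):

import torch

def create_masks(trg, pad_idx=0):
    # Hypothetical: hide padding and future positions from the decoder.
    trg_pad_mask = (trg != pad_idx).unsqueeze(1)  # (B, 1, T)
    T = trg.size(1)
    no_peek = torch.tril(torch.ones(1, T, T, dtype=torch.bool,
                                    device=trg.device))
    return trg_pad_mask & no_peek                 # (B, T, T)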
Example #15
class ALADTrainer:
    def __init__(self, args, data, device):
        self.args = args
        self.train_loader, _ = data
        self.device = device
        self.build_models()

    def train(self):
        """Training the ALAD"""

        if self.args.pretrained:
            self.load_weights()

        optimizer_ge = optim.Adam(list(self.G.parameters()) +
                                  list(self.E.parameters()),
                                  lr=self.args.lr,
                                  betas=(0.5, 0.999))
        params_ = list(self.Dxz.parameters()) \
                + list(self.Dzz.parameters()) \
                + list(self.Dxx.parameters())
        optimizer_d = optim.Adam(params_, lr=self.args.lr, betas=(0.5, 0.999))

        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs + 1):
            ge_losses = 0
            d_losses = 0
            for x, _ in Bar(self.train_loader):
                #Defining labels
                y_true = Variable(torch.ones((x.size(0), 1)).to(self.device))
                y_fake = Variable(torch.zeros((x.size(0), 1)).to(self.device))

                #Cleaning gradients.
                optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                #Generator:
                z_real = Variable(torch.randn(
                    (x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_gen = self.G(z_real)

                #Encoder:
                x_real = x.float().to(self.device)
                z_gen = self.E(x_real)

                #Discriminatorxz
                out_truexz, _ = self.Dxz(x_real, z_gen)
                out_fakexz, _ = self.Dxz(x_gen, z_real)

                #Discriminatorzz
                out_truezz, _ = self.Dzz(z_real, z_real)
                out_fakezz, _ = self.Dzz(z_real, self.E(self.G(z_real)))

                #Discriminatorxx
                out_truexx, _ = self.Dxx(x_real, x_real)
                out_fakexx, _ = self.Dxx(x_real, self.G(self.E(x_real)))

                #Losses
                loss_dxz = criterion(out_truexz, y_true) + criterion(
                    out_fakexz, y_fake)
                loss_dzz = criterion(out_truezz, y_true) + criterion(
                    out_fakezz, y_fake)
                loss_dxx = criterion(out_truexx, y_true) + criterion(
                    out_fakexx, y_fake)
                loss_d = loss_dxz + loss_dzz + loss_dxx

                loss_gexz = criterion(out_fakexz, y_true) + criterion(
                    out_truexz, y_fake)
                loss_gezz = criterion(out_fakezz, y_true) + criterion(
                    out_truezz, y_fake)
                loss_gexx = criterion(out_fakexx, y_true) + criterion(
                    out_truexx, y_fake)
                cycle_consistency = loss_gezz + loss_gexx
                loss_ge = loss_gexz + loss_gezz + loss_gexx  # + cycle_consistency
                #Computing gradients and backpropagate.
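                # loss_d and loss_ge share intermediates of the same graph,
                # so the first backward pass must retain it for the second.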
                loss_d.backward(retain_graph=True)
                loss_ge.backward()
                optimizer_d.step()
                optimizer_ge.step()

                d_losses += loss_d.item()

                ge_losses += loss_ge.item()

            if epoch % 10 == 0:
                vutils.save_image((self.G(fixed_z).data + 1) / 2.,
                                  './images/{}_fake.png'.format(epoch))

            print(
                "Training... Epoch: {}, Discrimiantor Loss: {:.3f}, Generator Loss: {:.3f}"
                .format(epoch, d_losses / len(self.train_loader),
                        ge_losses / len(self.train_loader)))
        self.save_weights()

    def build_models(self):
        self.G = Generator(self.args.latent_dim).to(self.device)
        self.E = Encoder(self.args.latent_dim,
                         self.args.spec_norm).to(self.device)
        self.Dxz = Discriminatorxz(self.args.latent_dim,
                                   self.args.spec_norm).to(self.device)
        self.Dxx = Discriminatorxx(self.args.spec_norm).to(self.device)
        self.Dzz = Discriminatorzz(self.args.latent_dim,
                                   self.args.spec_norm).to(self.device)
        self.G.apply(weights_init_normal)
        self.E.apply(weights_init_normal)
        self.Dxz.apply(weights_init_normal)
        self.Dxx.apply(weights_init_normal)
        self.Dzz.apply(weights_init_normal)

    def save_weights(self):
        """Save weights."""
        state_dict_Dxz = self.Dxz.state_dict()
        state_dict_Dxx = self.Dxx.state_dict()
        state_dict_Dzz = self.Dzz.state_dict()
        state_dict_E = self.E.state_dict()
        state_dict_G = self.G.state_dict()
        torch.save(
            {
                'Generator': state_dict_G,
                'Encoder': state_dict_E,
                'Discriminatorxz': state_dict_Dxz,
                'Discriminatorxx': state_dict_Dxx,
                'Discriminatorzz': state_dict_Dzz
            },
            'weights/model_parameters_{}.pth'.format(self.args.normal_class))

    def load_weights(self):
        """Load weights."""
        # NB: save_weights() writes 'weights/model_parameters_{normal_class}.pth';
        # point this path at that file to reload a class-specific checkpoint.
        state_dict = torch.load('weights/model_parameters.pth')

        self.Dxz.load_state_dict(state_dict['Discriminatorxz'])
        self.Dxx.load_state_dict(state_dict['Discriminatorxx'])
        self.Dzz.load_state_dict(state_dict['Discriminatorzz'])
        self.G.load_state_dict(state_dict['Generator'])
        self.E.load_state_dict(state_dict['Encoder'])
Example #16
assert torch_pretrained.shape[0] >= total_parameter, 'not enough weights to load'
if torch_pretrained.shape[0] > total_parameter:
    print('Note: fewer parameters than pretrained weights !!!')


# Copying parameters
def copy_params(idx, parameters):
    for p in parameters:
        layer_p_num = np.prod(p.size())
        # in-place copies into leaf tensors must run under no_grad
        with torch.no_grad():
            p.view(-1).copy_(torch.FloatTensor(
                torch_pretrained[idx:idx+layer_p_num]))
        idx += layer_p_num
        print('copy pointer current position: %d' % idx, end='\r', flush=True)
    return idx


print('# of parameters matched, start copying')
idx = 0
if args.encoder:
    idx = copy_params(idx, encoder.parameters())
    torch.save(encoder.state_dict(), args.encoder)
if args.edg_decoder:
    idx = copy_params(idx, edg_decoder.parameters())
    torch.save(edg_decoder.state_dict(), args.edg_decoder)
if args.cor_decoder:
    idx = copy_params(idx, cor_decoder.parameters())
    torch.save(cor_decoder.state_dict(), args.cor_decoder)

print('\nAll done')
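# A standalone toy version (sizes are illustrative) of the copy_params loop
# above: walk a flat weight vector and copy consecutive slices into each
# parameter tensor in declaration order.
import numpy as np
import torch
import torch.nn as nn

flat = np.arange(6, dtype=np.float32)  # stand-in for torch_pretrained
layer = nn.Linear(2, 2)                # weight: 4 values, bias: 2 values
idx = 0
with torch.no_grad():
    for p in layer.parameters():
        n = int(np.prod(p.size()))
        p.view(-1).copy_(torch.from_numpy(flat[idx:idx + n]))
        idx += n
assert idx == flat.shape[0]            # every pretrained value consumed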
예제 #17
0
def train_dynamics(env, args, writer=None):
    """
    Trains the Dynamics module. Supervised.

    Arguments:
    env: the initialized environment (rllab/gym)
    args: input arguments
    writer: initialized summary writer for tensorboard
    """
    args.action_space = env.action_space

    # Initialize models
    enc = Encoder(env.observation_space.shape[0],
                  args.dim,
                  use_conv=args.use_conv)
    dec = Decoder(env.observation_space.shape[0],
                  args.dim,
                  use_conv=args.use_conv)
    d_module = D_Module(env.action_space.shape[0], args.dim, args.discrete)

    if args.from_checkpoint is not None:
        results_dict = torch.load(args.from_checkpoint)
        enc.load_state_dict(results_dict['enc'])
        dec.load_state_dict(results_dict['dec'])
        d_module.load_state_dict(results_dict['d_module'])

    # Materialise the parameter chain so it can be iterated more than once
    # (both the optimizer and gradient clipping below consume it).
    all_params = list(chain(enc.parameters(), dec.parameters(),
                            d_module.parameters()))

    if args.transfer:
        for p in enc.parameters():
            p.requires_grad = False

        for p in dec.parameters():
            p.requires_grad = False
        all_params = list(d_module.parameters())

    optimizer = torch.optim.Adam(all_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    if args.gpu:
        enc = enc.cuda()
        dec = dec.cuda()
        d_module = d_module.cuda()

    # Initialize datasets
    val_loader = None
    train_dataset = DynamicsDataset(args.train_set,
                                    args.train_size,
                                    batch=args.train_batch,
                                    rollout=args.rollout)
    val_dataset = DynamicsDataset(args.test_set,
                                  5000,
                                  batch=args.test_batch,
                                  rollout=args.rollout)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers)

    results_dict = {
        'dec_losses': [],
        'forward_losses': [],
        'inverse_losses': [],
        'total_losses': [],
        'enc': None,
        'dec': None,
        'd_module': None,
        'd_init': None,
        'args': args
    }

    total_action_taken = 0
    correct_predicted_a_hat = 0

    # create the mask here for re-weighting
    dec_mask = None
    if args.dec_mask is not None:
        dec_mask = torch.ones(9)
        game_vocab = dict([
            (b, a)
            for a, b in enumerate(sorted(env.game.all_possible_features()))
        ])
        dec_mask[game_vocab['Agent']] = args.dec_mask
        dec_mask[game_vocab['Goal']] = args.dec_mask
        dec_mask = dec_mask.expand(args.batch_size, args.maze_length,
                                   args.maze_length, 9).contiguous().view(-1)
        dec_mask = Variable(dec_mask, requires_grad=False)
        if args.gpu:
            dec_mask = dec_mask.cuda()

    for epoch in range(1, args.num_epochs + 1):
        enc.train()
        dec.train()
        d_module.train()

        if args.framework == "mazebase":
            d_init.train()  # d_init is assumed to be defined at module scope for mazebase runs

        # for measuring the accuracy
        train_acc = 0
        current_epoch_actions = 0
        current_epoch_predicted_a_hat = 0

        start = time.time()
        for i, (states, target_actions) in enumerate(train_loader):

            optimizer.zero_grad()

            if args.framework != "mazebase":
                forward_loss, inv_loss, dec_loss, recon_loss, model_loss, _, _ = forward_planning(
                    i, states, target_actions, enc, dec, d_module, args)
            else:
                forward_loss, inv_loss, dec_loss, recon_loss, model_loss, current_epoch_predicted_a_hat, current_epoch_actions = multiple_forward(
                    i, states, target_actions, enc, dec, d_module, args,
                    d_init, dec_mask)

            loss = forward_loss + args.inv_loss_coef * inv_loss + \
                        args.dec_loss_coef * dec_loss

            if i % args.log_interval == 0:
                log(
                    'Epoch [{}/{}]\tIter [{}/{}]\t'.format(
                        epoch, args.num_epochs, i+1, len(
                        train_dataset)//args.batch_size) + \
                    'Time: {:.2f}\t'.format(time.time() - start) + \
                    'Decoder Loss: {:.2f}\t'.format(dec_loss.item()) + \
                    'Forward Loss: {:.2f}\t'.format(forward_loss.item()) + \
                    'Inverse Loss: {:.2f}\t'.format(inv_loss.item()) + \
                    'Loss: {:.2f}\t'.format(loss.item()))

                results_dict['dec_losses'].append(dec_loss.item())
                results_dict['forward_losses'].append(forward_loss.item())
                results_dict['inverse_losses'].append(inv_loss.item())
                results_dict['total_losses'].append(loss.item())

                # write the summaries here
                if writer:
                    writer.add_scalar('dynamics/total_loss', loss.item(),
                                      epoch)
                    writer.add_scalar('dynamics/decoder', dec_loss.item(),
                                      epoch)
                    writer.add_scalar('dynamics/reconstruction_loss',
                                      recon_loss.item(), epoch)
                    writer.add_scalar('dynamics/next_state_prediction_loss',
                                      model_loss.item(), epoch)
                    writer.add_scalar('dynamics/inv_loss', inv_loss.item(),
                                      epoch)
                    writer.add_scalar('dynamics/forward_loss',
                                      forward_loss.item(), epoch)

                    writer.add_scalars(
                        'dynamics/all_losses', {
                            "total_loss": loss.item(),
                            "reconstruction_loss": recon_loss.item(),
                            "next_state_prediction_loss": model_loss.item(),
                            "decoder_loss": dec_loss.item(),
                            "inv_loss": inv_loss.item(),
                            "forward_loss": forward_loss.item(),
                        }, epoch)

            loss.backward()

            correct_predicted_a_hat += current_epoch_predicted_a_hat
            total_action_taken += current_epoch_actions

            # does it not work at all without grad clipping ?
            torch.nn.utils.clip_grad_norm_(all_params, args.max_grad_norm)
            optimizer.step()

            # maybe add the generated image to add the logs
            # writer.add_image()

        # Run validation
        if val_loader is not None:
            enc.eval()
            dec.eval()
            d_module.eval()
            forward_loss, inv_loss, dec_loss = 0, 0, 0
            for i, (states, target_actions) in enumerate(val_loader):
                f_loss, i_loss, d_loss, _, _, _, _ = forward_planning(
                    i, states, target_actions, enc, dec, d_module, args)
                forward_loss += f_loss
                inv_loss += i_loss
                dec_loss += d_loss
            loss = forward_loss + args.inv_loss_coef * inv_loss + \
                    args.dec_loss_coef * dec_loss
            num_val_batches = len(val_loader)  # for averaging the summed losses
            if writer:
                writer.add_scalar('val/forward_loss',
                                  forward_loss.item() / num_val_batches, epoch)
                writer.add_scalar('val/inverse_loss',
                                  inv_loss.item() / num_val_batches, epoch)
                writer.add_scalar('val/decoder_loss',
                                  dec_loss.item() / num_val_batches, epoch)
            log(
                '[Validation]\t' + \
                'Decoder Loss: {:.2f}\t'.format(dec_loss.item() / num_val_batches) + \
                'Forward Loss: {:.2f}\t'.format(forward_loss.item() / num_val_batches) + \
                'Inverse Loss: {:.2f}\t'.format(inv_loss.item() / num_val_batches) + \
                'Loss: {:.2f}\t'.format(loss.item() / num_val_batches))
        if epoch % args.checkpoint == 0:
            results_dict['enc'] = enc.state_dict()
            results_dict['dec'] = dec.state_dict()
            results_dict['d_module'] = d_module.state_dict()
            if args.framework == "mazebase":
                results_dict['d_init'] = d_init.state_dict()
            torch.save(
                results_dict,
                os.path.join(args.out, 'dynamics_module_epoch%s.pt' % epoch))
            log('Saved model %s' % epoch)

    results_dict['enc'] = enc.state_dict()
    results_dict['dec'] = dec.state_dict()
    results_dict['d_module'] = d_module.state_dict()
    torch.save(results_dict,
               os.path.join(args.out, 'dynamics_module_epoch%s.pt' % epoch))
    print(os.path.join(args.out, 'dynamics_module_epoch%s.pt' % epoch))
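# A standalone illustration of the dec_mask construction in train_dynamics
# above: up-weight two vocabulary slots (stand-ins for 'Agent' and 'Goal'),
# then tile the per-symbol mask over the (batch, maze, maze) grid. All sizes
# and indices here are illustrative assumptions.
import torch

vocab_size, batch_size, maze_length = 9, 2, 3
mask = torch.ones(vocab_size)
mask[0] = mask[4] = 5.0   # hypothetical 'Agent' / 'Goal' vocabulary indices
flat = mask.expand(batch_size, maze_length, maze_length,
                   vocab_size).contiguous().view(-1)
print(flat.shape)         # torch.Size([162]) == 2 * 3 * 3 * 9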
예제 #18
0
class EGBADTrainer:
    def __init__(self, args, data, device):
        self.args = args
        self.train_loader, _ = data
        self.device = device
        self.build_models()

    def train(self):
        """Train the EGBAD model."""

        if self.args.pretrained:
            self.load_weights()

        optimizer_ge = optim.Adam(list(self.G.parameters()) +
                                  list(self.E.parameters()),
                                  lr=self.args.lr)
        optimizer_d = optim.Adam(self.D.parameters(), lr=self.args.lr)

        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs + 1):
            ge_losses = 0
            d_losses = 0
            for x, _ in Bar(self.train_loader):
                #Defining labels
                y_true = Variable(torch.ones((x.size(0), 1)).to(self.device))
                y_fake = Variable(torch.zeros((x.size(0), 1)).to(self.device))

                #Noise for improving training.
                noise1 = Variable(torch.Tensor(x.size()).normal_(
                    0, 0.1 * (self.args.num_epochs - epoch) /
                    self.args.num_epochs),
                                  requires_grad=False).to(self.device)
                noise2 = Variable(torch.Tensor(x.size()).normal_(
                    0, 0.1 * (self.args.num_epochs - epoch) /
                    self.args.num_epochs),
                                  requires_grad=False).to(self.device)

                #Cleaning gradients.
                optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                #Generator:
                z_fake = Variable(torch.randn(
                    (x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_fake = self.G(z_fake)

                #Encoder:
                x_true = x.float().to(self.device)
                z_true = self.E(x_true)

                #Discriminator
                out_true = self.D(x_true + noise1, z_true)
                out_fake = self.D(x_fake + noise2, z_fake)

                #Losses
                loss_d = criterion(out_true, y_true) + criterion(
                    out_fake, y_fake)
                loss_ge = criterion(out_fake, y_true) + criterion(
                    out_true, y_fake)

                # Compute gradients and backpropagate.
                loss_d.backward(retain_graph=True)
                optimizer_d.step()

                loss_ge.backward()
                optimizer_ge.step()

                ge_losses += loss_ge.item()
                d_losses += loss_d.item()

            if epoch % 10 == 0:
                vutils.save_image((self.G(fixed_z).data + 1) / 2.,
                                  './images/{}_fake.png'.format(epoch))

            print(
                "Training... Epoch: {}, Discriminator Loss: {:.3f}, Generator Loss: {:.3f}"
                .format(epoch, d_losses / len(self.train_loader),
                        ge_losses / len(self.train_loader)))
        self.save_weights()

    def build_models(self):
        self.G = Generator(self.args.latent_dim).to(self.device)
        self.E = Encoder(self.args.latent_dim).to(self.device)
        self.D = Discriminator(self.args.latent_dim).to(self.device)
        self.G.apply(weights_init_normal)
        self.E.apply(weights_init_normal)
        self.D.apply(weights_init_normal)

    def save_weights(self):
        """Save weights."""
        state_dict_D = self.D.state_dict()
        state_dict_E = self.E.state_dict()
        state_dict_G = self.G.state_dict()
        torch.save(
            {
                'Generator': state_dict_G,
                'Encoder': state_dict_E,
                'Discriminator': state_dict_D
            }, 'weights/model_parameters.pth')

    def load_weights(self):
        """Load weights."""
        state_dict = torch.load('weights/model_parameters.pth')

        self.D.load_state_dict(state_dict['Discriminator'])
        self.G.load_state_dict(state_dict['Generator'])
        self.E.load_state_dict(state_dict['Encoder'])
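# A standalone check (all parameters illustrative) of the instance-noise
# schedule used in the trainer above: the noise standard deviation starts at
# 0.1 and decays linearly to zero over the course of training.
import torch

num_epochs = 100
for epoch in (0, 50, 99):
    std = 0.1 * (num_epochs - epoch) / num_epochs
    noise = torch.empty(8, 1, 32, 32).normal_(0, std)
    print(epoch, round(std, 4), float(noise.std()))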
예제 #19
0
                        accuracy_matrix[i][:train_decoder_len_mask[i]] == 0))
            accuracy /= float(sum(train_decoder_len_mask))
            viz.line(X=torch.FloatTensor([step_global]),
                     Y=torch.FloatTensor([accuracy]),
                     win='acc',
                     update=update,
                     opts=opts_acc)

            # Update the loss plot
            viz.line(X=torch.FloatTensor([step_global]),
                     Y=torch.FloatTensor([loss]),
                     win='loss',
                     update=update,
                     opts=opts_loss)

    # Save the model at the end of every epoch, overwriting the previous save
    # so only the final iteration's weights remain.
    torch.save(encoder.state_dict(), './save_model/encoder_params_fe.pkl')
    torch.save(decoder.state_dict(), './save_model/decoder_params_fe.pkl')

    # Record the end time
    end = time.process_time()
    viz.text(time.strftime("ENDS AT %a %b %d %H:%M:%S %Y \n\n",
                           time.localtime()),
             win='summary',
             append=True)
    m, s = divmod(end - start, 60)
    h, m = divmod(m, 60)
    viz.text("time cost: %02d:%02d:%02d" % (h, m, s),
             win='summary',
             append=True)
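# A standalone check of the H:MM:SS formatting used above (elapsed time is
# made up).
elapsed = 3725.0
m, s = divmod(elapsed, 60)
h, m = divmod(m, 60)
print("time cost: %02d:%02d:%02d" % (h, m, s))   # time cost: 01:02:05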
예제 #20
0
def train(resume=False):

    it = 0

    writer = SummaryWriter('../runs/' + hparams.exp_name)

    for k in hparams.__dict__.keys():
        writer.add_text(str(k), str(hparams.__dict__[k]))

    train_dataset = ChestData(
        data_csv=hparams.train_csv,
        data_dir=hparams.train_dir,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((0.485), (0.229))
        ]))

    validation_dataset = ChestData(
        data_csv=hparams.valid_csv,
        data_dir=hparams.valid_dir,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((0.485), (0.229))
        ]))

    train_loader = DataLoader(train_dataset,
                              batch_size=hparams.batch_size,
                              shuffle=True,
                              num_workers=0)

    validation_loader = DataLoader(validation_dataset,
                                   batch_size=hparams.batch_size,
                                   shuffle=True,
                                   num_workers=0)

    print('loaded train data of length : {}'.format(len(train_dataset)))

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    def validation(encoder_, decoder_=None, send_stats=False, epoch=0):
        encoder_ = encoder_.eval()
        if decoder_:
            decoder_ = decoder_.eval()
        # print('Validating model on {0} examples. '.format(len(validation_loader)))
        with torch.no_grad():
            scores_list = []
            labels_list = []
            val_loss = 0
            for (img, labels, imgs_names) in validation_loader:
                img = Variable(img.float(), requires_grad=False)
                labels = Variable(labels.float(), requires_grad=False)
                scores = None
                if hparams.cuda:
                    img = img.cuda(hparams.gpu_device)
                    labels = labels.cuda(hparams.gpu_device)

                z = encoder_(img)

                if decoder_:
                    outputs = decoder_(z)
                    scores = torch.sum(
                        (outputs - img)**2, dim=tuple(range(
                            1, outputs.dim())))  # (outputs - img) ** 2
                    # rec_loss = rec_loss.view(outputs.shape[0], -1)
                    # rec_loss = torch.sum(torch.sum(rec_loss, dim=1))
                    val_loss += torch.sum(scores)
                    save_image(img,
                               'tmp/img_{}.png'.format(epoch),
                               normalize=True)
                    save_image(outputs,
                               'tmp/reconstructed_{}.png'.format(epoch),
                               normalize=True)

                else:
                    dist = torch.sum((z - encoder_.center)**2, dim=1)
                    if hparams.objective == 'soft-boundary':
                        scores = dist - encoder_.radius**2
                        val_loss += (1 / hparams.nu) * torch.sum(
                            torch.max(torch.zeros_like(scores), scores))
                    else:
                        scores = dist
                        val_loss += torch.sum(dist)

                scores_list.append(scores)
                labels_list.append(labels)

            scores = torch.cat(scores_list, dim=0)
            labels = torch.cat(labels_list, dim=0)

            val_loss /= len(validation_dataset)
            val_loss += (encoder_.radius**2 if not decoder_
                         and hparams.objective == 'soft-boundary' else 0)

            if hparams.cuda:
                labels = labels.cpu()
                scores = scores.cpu()

            labels = labels.view(-1).numpy()
            scores = scores.view(-1).detach().numpy()

            auc = roc_auc_score(labels, scores)

        return auc, val_loss

    ### validation function ends.

    if hparams.cuda:
        encoder = Encoder().cuda(hparams.gpu_device)
        decoder = Decoder().cuda(hparams.gpu_device)
    else:
        encoder = Encoder()
        decoder = Decoder()

    params_count = 0
    for param in encoder.parameters():
        params_count += np.prod(param.size())
    for param in decoder.parameters():
        params_count += np.prod(param.size())
    print('Model has {0} trainable parameters'.format(params_count))

    if not hparams.load_model:
        encoder.apply(weights_init_normal)
        decoder.apply(weights_init_normal)

    optim_params = list(encoder.parameters())
    optimizer_train = optim.Adam(optim_params,
                                 lr=hparams.train_lr,
                                 weight_decay=hparams.weight_decay,
                                 amsgrad=hparams.optimizer == 'amsgrad')

    if hparams.pretrain:
        optim_params += list(decoder.parameters())
        optimizer_pre = optim.Adam(optim_params,
                                   lr=hparams.pretrain_lr,
                                   weight_decay=hparams.ae_weight_decay,
                                   amsgrad=hparams.optimizer == 'amsgrad')
        # scheduler_pre = ReduceLROnPlateau(optimizer_pre, mode='min', factor=0.5, patience=10, verbose=True, cooldown=20)
        scheduler_pre = MultiStepLR(optimizer_pre,
                                    milestones=hparams.lr_milestones,
                                    gamma=0.1)

    # scheduler_train = ReduceLROnPlateau(optimizer_train, mode='min', factor=0.5, patience=10, verbose=True, cooldown=20)
    scheduler_train = MultiStepLR(optimizer_train,
                                  milestones=hparams.lr_milestones,
                                  gamma=0.1)

    print('Starting training.. (log saved in:{})'.format(hparams.exp_name))
    start_time = time.time()

    mode = 'pretrain' if hparams.pretrain else 'train'
    best_valid_loss = float('inf')
    best_valid_auc = 0
    encoder = init_center(encoder, train_loader)

    # print(model)
    for epoch in range(hparams.num_epochs):
        if mode == 'pretrain' and epoch == hparams.pretrain_epoch:
            print('Pretraining done.')
            mode = 'train'
            best_valid_loss = float('inf')
            best_valid_auc = 0
            encoder = init_center(encoder, train_loader)
        for batch, (imgs, labels, _) in enumerate(train_loader):

            # imgs = Variable(imgs.float(), requires_grad=False)

            if hparams.cuda:
                imgs = imgs.cuda(hparams.gpu_device)

            if mode == 'pretrain':
                optimizer_pre.zero_grad()
                z = encoder(imgs)
                outputs = decoder(z)
                # print(torch.max(outputs), torch.mean(imgs), torch.min(outputs), torch.mean(imgs))
                scores = torch.sum((outputs - imgs)**2,
                                   dim=tuple(range(1, outputs.dim())))
                # print(scores)
                loss = torch.mean(scores)
                loss.backward()
                optimizer_pre.step()
                writer.add_scalar('pretrain_loss',
                                  loss.item(),
                                  global_step=batch +
                                  len(train_loader) * epoch)

            else:
                optimizer_train.zero_grad()

                z = encoder(imgs)
                dist = torch.sum((z - encoder.center)**2, dim=1)
                if hparams.objective == 'soft-boundary':
                    scores = dist - encoder.radius**2
                    loss = encoder.radius**2 + (1 / hparams.nu) * torch.mean(
                        torch.max(torch.zeros_like(scores), scores))
                else:
                    loss = torch.mean(dist)

                loss.backward()
                optimizer_train.step()

                if hparams.objective == 'soft-boundary' and epoch >= hparams.warmup_epochs:
                    R = np.quantile(np.sqrt(dist.clone().data.cpu().numpy()),
                                    1 - hparams.nu)
                    encoder.radius = torch.tensor(R)
                    if hparams.cuda:
                        encoder.radius = encoder.radius.cuda(
                            hparams.gpu_device)
                    writer.add_scalar('radius',
                                      encoder.radius.item(),
                                      global_step=batch +
                                      len(train_loader) * epoch)
                writer.add_scalar('train_loss',
                                  loss.item(),
                                  global_step=batch +
                                  len(train_loader) * epoch)

            # pred_labels = (scores >= hparams.thresh)

            # save_image(imgs, 'train_imgs.png')
            # save_image(noisy_imgs, 'train_noisy.png')
            # save_image(gen_imgs, 'train_z.png')

            if batch % hparams.print_interval == 0:
                print('[Epoch - {0:.1f}, batch - {1:.3f}, loss - {2:.6f}]'.\
                format(1.0*epoch, 100.0*batch/len(train_loader), loss.item()))

        if mode == 'pretrain':
            val_auc, rec_loss = validation(copy.deepcopy(encoder),
                                           copy.deepcopy(decoder),
                                           epoch=epoch)
        else:
            val_auc, val_loss = validation(copy.deepcopy(encoder), epoch=epoch)

        writer.add_scalar('val_auc', val_auc, global_step=epoch)

        if mode == 'pretrain':
            best_valid_auc = max(best_valid_auc, val_auc)
            scheduler_pre.step()
            writer.add_scalar('rec_loss', rec_loss, global_step=epoch)
            writer.add_scalar('pretrain_lr',
                              optimizer_pre.param_groups[0]['lr'],
                              global_step=epoch)
            torch.save(
                {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict(),
                    'decoder_state_dict': decoder.state_dict(),
                    'optimizer_pre_state_dict': optimizer_pre.state_dict(),
                }, hparams.model + '.pre')
            if best_valid_loss >= rec_loss:
                best_valid_loss = rec_loss
                torch.save(
                    {
                        'epoch': epoch,
                        'encoder_state_dict': encoder.state_dict(),
                        'decoder_state_dict': decoder.state_dict(),
                        'optimizer_pre_state_dict': optimizer_pre.state_dict(),
                    }, hparams.model + '.pre.best')
                print('best model on validation set saved.')
            print('[Epoch - {0:.1f} ---> rec_loss - {1:.4f}, current_lr - {2:.6f}, val_auc - {3:.4f}, best_valid_auc - {4:.4f}] - time - {5:.1f}'\
                .format(1.0*epoch, rec_loss, optimizer_pre.param_groups[0]['lr'], val_auc, best_valid_auc, time.time()-start_time))

        else:
            scheduler_train.step()
            writer.add_scalar('val_loss', val_loss, global_step=epoch)
            writer.add_scalar('train_lr',
                              optimizer_train.param_groups[0]['lr'],
                              global_step=epoch)
            torch.save(
                {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict(),
                    'center': encoder.center,
                    'radius': encoder.radius,
                    'optimizer_train_state_dict': optimizer_train.state_dict(),
                }, hparams.model + '.train')
            if best_valid_loss >= val_loss:
                best_valid_loss = val_loss
                torch.save(
                    {
                        'epoch': epoch,
                        'encoder_state_dict': encoder.state_dict(),
                        'center': encoder.center,
                        'radius': encoder.radius,
                        'optimizer_train_state_dict':
                        optimizer_train.state_dict(),
                    }, hparams.model + '.train.best')
                print('best model on validation set saved.')
            if best_valid_auc <= val_auc:
                best_valid_auc = val_auc
                torch.save(
                    {
                        'epoch': epoch,
                        'encoder_state_dict': encoder.state_dict(),
                        'center': encoder.center,
                        'radius': encoder.radius,
                        'optimizer_train_state_dict':
                        optimizer_train.state_dict(),
                    }, hparams.model + '.train.auc')
                print('best model on validation set saved.')
            print('[Epoch - {0:.1f} ---> val_loss - {1:.4f}, current_lr - {2:.6f}, val_auc - {3:.4f}, best_valid_auc - {4:.4f}] - time - {5:.1f}'\
                .format(1.0*epoch, val_loss, optimizer_train.param_groups[0]['lr'], val_auc, best_valid_auc, time.time()-start_time))

        start_time = time.time()
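# A standalone illustration (with made-up distances) of the soft-boundary
# radius update above: after the warm-up epochs, the radius is set to the
# (1 - nu) quantile of the distances to the hypersphere center, so roughly a
# nu fraction of training points fall outside the boundary.
import numpy as np
import torch

nu = 0.05
dist = torch.rand(1000) * 4.0                 # stand-in squared distances
R = np.quantile(np.sqrt(dist.numpy()), 1 - nu)
radius = torch.tensor(R)
print(float(radius))                          # roughly the 95th-percentile distance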
예제 #21
0
def main():
    # Create model directory
    ##### arguments #####
    PATH = os.getcwd()
    image_dir = './data/resized2014/'
    caption_path = './data/annotations/captions_train2014.json'
    vocab_path = './data/vocab.pkl'
    model_path = './model'
    crop_size = 224
    batch_size = 128
    num_workers = 4
    learning_rate = 0.001

    # Decoder
    embed_size = 512
    hidden_size = 512
    num_layers = 3  # number of lstm layers
    num_epochs = 10
    start_epoch = 0
    save_step = 3000

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    coco = CocoDataset(image_dir, caption_path, vocab, transform)
    dataLoader = torch.utils.data.DataLoader(coco,
                                             batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             collate_fn=coco_batch)

    # Declare the encoder decoder
    encoder = Encoder(embed_size=embed_size).to(device)
    decoder = Decoder(embed_size=embed_size,
                      hidden_size=hidden_size,
                      vocab_size=len(vocab),
                      num_layers=num_layers).to(device)

    encoder.train()
    decoder.train()
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
    # For encoder only train the last fc layer
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    # Train the models
    total_step = len(dataLoader)
    for epoch in range(num_epochs):
        for i, (images, captions, lengths) in enumerate(dataLoader):
            # Set mini-batch dataset
            images = images.cuda()
            captions = captions.cuda()
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()

            # Workaround seen in some captioning repos: cap Adam's internal
            # per-parameter step counter to keep the bias-correction term from
            # degenerating on very long runs.
            for group in optimizer.param_groups:
                for p in group['params']:
                    state = optimizer.state[p]
                    if ('step' in state and state['step'] >= 1024):
                        state['step'] = 1000

            loss.backward(retain_graph=True)
            optimizer.step()

            # Print log info
            if i % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1 + start_epoch, num_epochs + start_epoch, i,
                    total_step, loss.item()))

            # Save the model checkpoints
            if (i + 1) % save_step == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join(
                        model_path,
                        'decoder-{}-{}.ckpt'.format(epoch + 1 + start_epoch,
                                                    i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(
                        model_path,
                        'encoder-{}-{}.ckpt'.format(epoch + 1 + start_epoch,
                                                    i + 1)))

        print('epoch ', epoch + 1, 'loss: ', loss.item())
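# A standalone illustration of the pack_padded_sequence target construction
# above: packing drops the padding and interleaves timesteps, so the targets
# line up with the decoder's packed outputs.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

captions = torch.tensor([[1, 2, 3], [4, 5, 0]])   # second caption is padded
lengths = [3, 2]
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
print(targets)   # tensor([1, 4, 2, 5, 3]) -- time-major, padding removed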
예제 #22
0
File: train.py  Project: kondo-kk/flu_trend
def train(region):
    np.random.seed(0)
    torch.manual_seed(0)

    input_len = 10
    encoder_units = 32
    decoder_units = 64
    encoder_rnn_layers = 3
    encoder_dropout = 0.2
    decoder_dropout = 0.2
    input_size = 2
    output_size = 1
    predict_len = 5
    batch_size = 16
    epochs = 500
    force_teacher = 0.8

    train_dataset, test_dataset, train_max, train_min = create_dataset(
        input_len, predict_len, region)
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    enc = Encoder(input_size, encoder_units, input_len,
                  encoder_rnn_layers, encoder_dropout)
    dec = Decoder(encoder_units*2, decoder_units, input_len,
                  input_len, decoder_dropout, output_size)

    optimizer = AdaBound(list(enc.parameters()) +
                         list(dec.parameters()), 0.01, final_lr=0.1)
    # optimizer = optim.Adam(list(enc.parameters()) + list(dec.parameters()), 0.01)
    criterion = nn.MSELoss()

    mb = master_bar(range(epochs))
    for ep in mb:
        train_loss = 0
        enc.train()
        dec.train()
        for encoder_input, decoder_input, target in progress_bar(train_loader, parent=mb):
            optimizer.zero_grad()
            enc_vec = enc(encoder_input)
            h = enc_vec[:, -1, :]
            _, c = dec.initHidden(batch_size)
            x = decoder_input[:, 0]
            pred = []
            for pi in range(predict_len):
                x, h, c = dec(x, h, c, enc_vec)
                rand = np.random.random()
                pred += [x]
                if rand < force_teacher:
                    x = decoder_input[:, pi]
            pred = torch.cat(pred, dim=1)
            # loss = quantile_loss(pred, target)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        test_loss = 0
        enc.eval()
        dec.eval()
        for encoder_input, decoder_input, target in progress_bar(test_loader, parent=mb):
            with torch.no_grad():
                enc_vec = enc(encoder_input)
                h = enc_vec[:, -1, :]
                _, c = dec.initHidden(batch_size)
                x = decoder_input[:, 0]
                pred = []
                for pi in range(predict_len):
                    x, h, c = dec(x, h, c, enc_vec)
                    pred += [x]
                pred = torch.cat(pred, dim=1)
            # loss = quantile_loss(pred, target)
            loss = criterion(pred, target)
            test_loss += loss.item()
        print(
            f"Epoch {ep} Train Loss {train_loss/len(train_loader)} Test Loss {test_loss/len(test_loader)}")

    if not os.path.exists("models"):
        os.mkdir("models")
    torch.save(enc.state_dict(), f"models/{region}_enc.pth")
    torch.save(dec.state_dict(), f"models/{region}_dec.pth")

    test_loader = DataLoader(test_dataset, batch_size=1,
                             shuffle=False, drop_last=False)

    rmse = 0
    p = 0
    predicted = []
    true_target = []
    enc.eval()
    dec.eval()
    for encoder_input, decoder_input, target in progress_bar(test_loader, parent=mb):
        with torch.no_grad():
            enc_vec = enc(encoder_input)
            x = decoder_input[:, 0]
            h = enc_vec[:, -1, :]   # use the encoder's final state, matching the evaluation loop above
            _, c = dec.initHidden(1)
            pred = []
            for pi in range(predict_len):
                x, h, c = dec(x, h, c, enc_vec)
                pred += [x]
            pred = torch.cat(pred, dim=1)
            predicted += [pred[0, p].item()]
            true_target += [target[0, p].item()]
    predicted = np.array(predicted).reshape(1, -1)
    predicted = predicted * (train_max - train_min) + train_min
    true_target = np.array(true_target).reshape(1, -1)
    true_target = true_target * (train_max - train_min) + train_min
    rmse, pearsonr = calc_metric(predicted, true_target)
    print(f"{region} RMSE {rmse}")
    print(f"{region} r {pearsonr[0]}")
    return f"{region} RMSE {rmse} r {pearsonr[0]}"
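# A standalone check (bounds are made up) of the min-max denormalisation
# applied to the predictions above.
import numpy as np

train_min, train_max = 10.0, 50.0
scaled = np.array([[0.0, 0.5, 1.0]])
restored = scaled * (train_max - train_min) + train_min
print(restored)   # [[10. 30. 50.]]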
예제 #23
0
if __name__ == '__main__':

    if len(sys.argv) != 4:
        print('Usage: python predict.py [image.npy] [test_case.csv] [ans.csv]')
        exit(0)
    else:
        fp_data = sys.argv[1]
        fp_ind = sys.argv[2]
        fp_ans = sys.argv[3]

fp_model_fe = 'model6.fe.pt'

state_dict = torch.load(fp_model_fe)

model_enc = Encoder()
model_enc_dict = model_enc.state_dict()
model_enc_dict.update({k: v for k, v in state_dict.items() \
                            if k in model_enc_dict})
model_enc.load_state_dict(model_enc_dict)
model_enc.cuda()

test_loader = load_data(fp_data)
features = predict(model_enc, test_loader)

ind = (pd.read_csv(fp_ind, delimiter=',').values)[:, 1:]

pred = []
for i in range(ind.shape[0]):
    if np.linalg.norm(features[ind[i][0]] - features[ind[i][1]]) > 10:
        pred.append(0)
    else:
        pred.append(1)
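# A standalone illustration (embeddings are made up) of the verification rule
# above: a pair counts as a match (1) when the L2 distance between the two
# feature vectors is at most 10, otherwise as a non-match (0).
import numpy as np

f1, f2 = np.zeros(128), np.full(128, 0.5)
pred = 0 if np.linalg.norm(f1 - f2) > 10 else 1
print(pred)   # norm = 0.5 * sqrt(128) ~= 5.66, so pred == 1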
예제 #24
0
File: train.py  Project: v-juma1/textent
def train(description_db, entity_db, word_vocab, entity_vocab,
          target_entity_vocab, out_file, embeddings, dim_size, batch_size,
          negative, epoch, optimizer, max_text_len, max_entity_len, pool_size,
          seed, save, **model_params):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    word_matrix = np.random.uniform(low=-0.05,
                                    high=0.05,
                                    size=(word_vocab.size, dim_size))
    word_matrix = np.vstack([np.zeros(dim_size),
                             word_matrix]).astype('float32')

    entity_matrix = np.random.uniform(low=-0.05,
                                      high=0.05,
                                      size=(entity_vocab.size, dim_size))
    entity_matrix = np.vstack([np.zeros(dim_size),
                               entity_matrix]).astype('float32')

    target_entity_matrix = np.random.uniform(low=-0.05,
                                             high=0.05,
                                             size=(target_entity_vocab.size,
                                                   dim_size))
    target_entity_matrix = np.vstack(
        [np.zeros(dim_size), target_entity_matrix]).astype('float32')

    for embedding in embeddings:
        for word in word_vocab:
            vec = embedding.get_word_vector(word)
            if vec is not None:
                word_matrix[word_vocab.get_index(word)] = vec

        for title in entity_vocab:
            vec = embedding.get_entity_vector(title)
            if vec is not None:
                entity_matrix[entity_vocab.get_index(title)] = vec

        for title in target_entity_vocab:
            vec = embedding.get_entity_vector(title)
            if vec is not None:
                target_entity_matrix[target_entity_vocab.get_index(
                    title)] = vec

    entity_negatives = np.arange(1, target_entity_matrix.shape[0])

    model_params.update(dict(dim_size=dim_size))
    model = Encoder(word_embedding=word_matrix,
                    entity_embedding=entity_matrix,
                    target_entity_embedding=target_entity_matrix,
                    word_vocab=word_vocab,
                    entity_vocab=entity_vocab,
                    target_entity_vocab=target_entity_vocab,
                    **model_params)

    del word_matrix
    del entity_matrix
    del target_entity_matrix

    model = model.cuda()

    model.train()
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer_ins = getattr(optim, optimizer)(parameters)

    n_correct = 0
    n_total = 0
    cur_correct = 0
    cur_total = 0
    cur_loss = 0.0

    batch_idx = 0

    joblib.dump(
        dict(model_params=model_params,
             word_vocab=word_vocab.serialize(),
             entity_vocab=entity_vocab.serialize(),
             target_entity_vocab=target_entity_vocab.serialize()),
        out_file + '.pkl')

    if not save or 0 in save:
        state_dict = model.state_dict()
        torch.save(state_dict, out_file + '_epoch0.bin')

    for n_epoch in range(1, epoch + 1):
        logger.info('Epoch: %d', n_epoch)

        for (batch_idx, (args, target)) in enumerate(
                generate_data(description_db, word_vocab, entity_vocab,
                              target_entity_vocab, entity_negatives,
                              batch_size, negative, max_text_len,
                              max_entity_len, pool_size), batch_idx):
            args = tuple([o.cuda(non_blocking=True) for o in args])
            target = target.cuda()

            optimizer_ins.zero_grad()
            output = model(args)
            loss = F.cross_entropy(output, target)
            loss.backward()

            optimizer_ins.step()

            cur_correct += (torch.max(output, 1)[1].view(
                target.size()) == target).sum().item()
            cur_total += len(target)
            cur_loss += loss.item()
            if batch_idx != 0 and batch_idx % 1000 == 0:
                n_correct += cur_correct
                n_total += cur_total
                logger.info(
                    'Processed %d batches (epoch: %d, loss: %.4f acc: %.4f total acc: %.4f)'
                    % (batch_idx, n_epoch, cur_loss / cur_total, 100. *
                       cur_correct / cur_total, 100. * n_correct / n_total))
                cur_correct = 0
                cur_total = 0
                cur_loss = 0.0
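# A standalone check of the batch-accuracy computation in the loop above:
# argmax over the class dimension, compared element-wise with the targets.
import torch

output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
target = torch.tensor([1, 1])
correct = (torch.max(output, 1)[1].view(target.size()) == target).sum().item()
print(correct)   # 1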
예제 #25
0
File: main.py  Project: mandiehyewon/AI502
		#		%(step+1, total_step, loss.data[0], accuracy.data[0]))

		#============ TensorBoard logging ============#
		# (1) Log the scalar values
		info = {
			'recon_loss': recon_loss.item(),
			'discriminator_loss': D_loss.item(),
			'generator_loss': G_loss.item()
		}

		for tag, value in info.items():
			logger.scalar_summary(tag, value, step+1)

		# (2) Log values and gradients of the parameters (histogram)
		for net,name in zip([Enc,Dec,Discrim],['Encoder','Decoder','Discrim']): 
			for tag, value in net.named_parameters():
				tag = name+tag.replace('.', '/')
				logger.histo_summary(tag, to_np(value), step+1)
				logger.histo_summary(tag+'/grad', to_np(value.grad), step+1)

		# (3) Log the images
		info = {
			'images': to_np(images.view(-1, 28, 28)[:10])
		}

		for tag, images in info.items():
			logger.image_summary(tag, images, step+1)

# Save the encoder weights.
torch.save(Enc.state_dict(),'Encoder_weights.pt')
예제 #26
0
    test_score = []
    valid_score = []
    for epoch in range(start, num_epochs):
        if (epoch == 45):
            if (opt.use_decay_learning and not opt.use_linearly_decay):
                print('decrease learning rate to half', flush=True)
                half_adjust_learning_rate(optimizerD, epoch, num_epochs)
                half_adjust_learning_rate(optimizerD2, epoch, num_epochs)
                half_adjust_learning_rate(optimizerG, epoch, num_epochs)
                half_adjust_learning_rate(optimizerE, epoch, num_epochs)
            else:
                print('still using 2e-4 for training')

            torch.save(netG.state_dict(), '%s/net45G.pth' % (saveModelRoot))
            torch.save(netD.state_dict(), '%s/net45D.pth' % (saveModelRoot))
            torch.save(netE.state_dict(), '%s/net45E.pth' % (saveModelRoot))
            torch.save(netD2.state_dict(), '%s/net45D2.pth' % (saveModelRoot))

        if (epoch == 60):
            print('saving epoch-60 models')

            torch.save(netG.state_dict(), '%s/net60G.pth' % (saveModelRoot))
            torch.save(netD.state_dict(), '%s/net60D.pth' % (saveModelRoot))
            torch.save(netE.state_dict(), '%s/net60E.pth' % (saveModelRoot))
            torch.save(netD2.state_dict(), '%s/net60D2.pth' % (saveModelRoot))

        if (epoch == 70):
            print('saving epoch-70 models')

            torch.save(netG.state_dict(), '%s/net70G.pth' % (saveModelRoot))
            torch.save(netD.state_dict(), '%s/net70D.pth' % (saveModelRoot))
            torch.save(netE.state_dict(), '%s/net70E.pth' % (saveModelRoot))
            torch.save(netD2.state_dict(), '%s/net70D2.pth' % (saveModelRoot))
예제 #27
0
    optimizer = torch.optim.Adam(params, lr=0.001)

    total_step = len(dataloader)
    for epoch in range(5):
        for i, (images, captions, lengths) in enumerate(dataloader):
            images = to_var(images)  # volatile=True would disable gradients and break training
            captions = to_var(captions)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]
            decoder.zero_grad()
            encoder.zero_grad()
            features = encoder(images)

            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                print("Epoch {} step {}, Loss: {}, Perplexity: {}".format(
                    epoch, i, loss.item(), np.exp(loss.item())))
            if (i + 1) % 1000 == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join("../output/",
                                 'decoder-{}-{}.pkl'.format(epoch, i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join("../output/",
                                 'encoder-{}-{}.pkl'.format(epoch, i + 1)))
예제 #28
0
        print("MACRO prec %.2f%%, recall %.2f%%, f1 %.2f%%" % (
            recall * 100, prec * 100, macro_f1 * 100))

        prec, recall, f1, _ = precision_recall_fscore_support(preds_all,
                                                              labels_all,
                                                              average='micro')
        print("MICRO prec %.2f%%, recall %.2f%%, f1 %.2f%%" % (
            recall * 100, prec * 100, f1 * 100))

        if args.test_model:
            break
        else:
            writer.add_scalar('micro_f1', f1 * 100, epoch)
            writer.add_scalar('macro_f1', macro_f1 * 100, epoch)
            save_dict = {
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'resume': epoch,
                'f1': f1,
                'iterations': iterations,
                'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
                'epochs_without_imp': epochs_without_imp
            }
            if swa_params:
                save_dict['decoder_swa_state_dict'] = decoder_swa.state_dict()
            if finetune_encoder:
                save_dict[
                    'encoder_optimizer_state_dict'] = encoder_optimizer.state_dict(
                    )
                if swa_params:
                    save_dict[
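# A standalone illustration (toy labels) of the micro vs. macro averaging
# reported above.
from sklearn.metrics import precision_recall_fscore_support

y_true = [0, 0, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2]
for avg in ('macro', 'micro'):
    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=avg)
    print("%s prec %.2f%% recall %.2f%% f1 %.2f%%"
          % (avg.upper(), p * 100, r * 100, f1 * 100))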
예제 #29
0
File: train.py  Project: bosung/cQA
def main(args):
    global batch_size
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    w_embed_size = args.w_embed_size
    lr = args.lr

    train_file = 'data/train_data_nv.txt'

    vocab = Vocab()
    vocab.build(train_file)

    if args.pre_trained_embed == 'n':
        encoder = Encoder(vocab.n_words, w_embed_size, hidden_size,
                          batch_size).to(device)
        decoder = AttentionDecoder(vocab.n_words, w_embed_size, hidden_size,
                                   batch_size).to(device)
    else:
        # load pre-trained embedding
        weight = vocab.load_weight(path="data/komoran_hd_2times.vec")
        encoder = Encoder(vocab.n_words, w_embed_size, hidden_size, batch_size,
                          weight).to(device)
        decoder = AttentionDecoder(vocab.n_words, w_embed_size, hidden_size,
                                   batch_size, weight).to(device)

    if args.encoder:
        encoder.load_state_dict(torch.load(args.encoder))
        print("[INFO] load encoder with %s" % args.encoder)
    if args.decoder:
        decoder.load_state_dict(torch.load(args.decoder))
        print("[INFO] load decoder with %s" % args.decoder)

    train_data = prep.read_train_data(train_file)
    train_loader = data.DataLoader(train_data,
                                   batch_size=batch_size,
                                   shuffle=True)

    # ev.evaluateRandomly(encoder, decoder, train_data, vocab, batch_size)
    # ev.evaluate_with_print(encoder, vocab, batch_size)

    # initialize
    max_a_at_5, max_a_at_1 = ev.evaluate_similarity(encoder,
                                                    vocab,
                                                    batch_size,
                                                    decoder=decoder)
    # max_a_at_5, max_a_at_1 = 0, 0
    max_bleu = 0

    total_epoch = args.epoch
    print(args)
    for epoch in range(1, total_epoch + 1):
        random.shuffle(train_data)
        trainIters(args,
                   epoch,
                   encoder,
                   decoder,
                   total_epoch,
                   train_data,
                   vocab,
                   train_loader,
                   print_every=2,
                   learning_rate=lr)

        if epoch % 20 == 0:
            a_at_5, a_at_1 = ev.evaluate_similarity(encoder,
                                                    vocab,
                                                    batch_size,
                                                    decoder=decoder)

            if a_at_1 > max_a_at_1:
                max_a_at_1 = a_at_1
                print("[INFO] New record! accuracy@1: %.4f" % a_at_1)

            if a_at_5 > max_a_at_5:
                max_a_at_5 = a_at_5
                print("[INFO] New record! accuracy@5: %.4f" % a_at_5)
                if args.save == 'y':
                    torch.save(encoder.state_dict(), 'encoder-max.model')
                    torch.save(decoder.state_dict(), 'decoder-max.model')
                    print("[INFO] new model saved")

            bleu = ev.evaluateRandomly(encoder, decoder, train_data, vocab,
                                       batch_size)
            if bleu > max_bleu:
                max_bleu = bleu
                if args.save == 'y':
                    torch.save(encoder.state_dict(), 'encoder-max-bleu.model')
                    torch.save(decoder.state_dict(), 'decoder-max-bleu.model')
                    print("[INFO] new model saved")

    print("Done! max accuracy@5: %.4f, max accuracy@1: %.4f" %
          (max_a_at_5, max_a_at_1))
    print("max bleu: %.2f" % max_bleu)
    if args.save == 'y':
        torch.save(encoder.state_dict(), 'encoder-last.model')
        torch.save(decoder.state_dict(), 'decoder-last.model')
예제 #30
0
                                                    batch_size,
                                                    decoder=decoder)

            if a_at_1 > max_a_at_1:
                max_a_at_1 = a_at_1
                print("[INFO] New record! accuracy@1: %.4f" % a_at_1)
                # if args.save == 'y':
                #    torch.save(encoder.state_dict(), 'encoder-max.model')
                #    torch.save(decoder.state_dict(), 'decoder-max.model')
                #    print("[INFO] new model saved")

            if a_at_5 > max_a_at_5:
                max_a_at_5 = a_at_5
                print("[INFO] New record! accuracy@5: %.4f" % a_at_5)

            bleu = ev.evaluateRandomly(encoder, decoder, train_data, vocab,
                                       batch_size)
            if bleu > max_bleu:
                max_bleu = bleu
                if args.save == 'y':
                    torch.save(encoder.state_dict(), 'encoder-max.model')
                    torch.save(decoder.state_dict(), 'decoder-max.model')
                    print("[INFO] new model saved")

    print("Done! max accuracy@5: %.4f, max accuracy@1: %.4f" %
          (max_a_at_5, max_a_at_1))
    print("max bleu: %.2f" % max_bleu)
    if args.save == 'y':
        torch.save(encoder.state_dict(), 'encoder-last.model')
        torch.save(decoder.state_dict(), 'decoder-last.model')