def main():
    """Train the character-level RNN on ``book`` and save its weights.

    Each iteration draws one character sequence with ``getSequence``, maps
    every character to its index in ``chars`` and performs a single Adam step
    on the NLL loss between the network output for ``dat[:-1]`` and the
    shifted targets ``dat[1:]``.  Every ``print_yes`` iterations the current
    loss and the output of ``test_words`` are printed.  When training ends the
    network's ``state_dict`` is written to ``trainedBook_v2.pt``.
    """
    epochs = 301
    seq_batch_size = 200
    print_yes = 100
    iscuda = False

    # create our network, optimizer and loss function
    net = RNN(len(chars), 100, 150, 2)  # instantiate an RNN object
    optim = torch.optim.Adam(net.parameters(), lr=6e-4)
    loss_func = torch.nn.functional.nll_loss

    if iscuda:
        net = net.cuda()

    # main training loop:
    for epoch in range(epochs):
        dat = getSequence(book, seq_batch_size)
        # find the corresponding index for each character and store it in a tensor
        dat = torch.LongTensor([chars.find(item) for item in dat])

        # initialize the hidden state; move data and state to the GPU if requested
        hidden = net.init_hidden()
        if iscuda:
            dat = dat.cuda()
            hidden = hidden.cuda()

        # inputs are all characters but the last, targets are shifted by one
        x_t = dat[:-1]
        y_t = dat[1:]

        # forward pass -- call the module itself (not ``.forward``) so that
        # any registered forward hooks are executed
        logprob, hidden = net(x_t, hidden)
        loss = loss_func(logprob, y_t)

        # update
        optim.zero_grad()
        loss.backward()
        optim.step()

        # print the loss for every kth iteration
        if epoch % print_yes == 0:
            print('*' * 60)
            # ``.item()`` yields a plain Python number instead of a tensor
            print('\n epoch {}, loss:{} \n'.format(epoch, loss.item()))
            print('sample speech:\n', test_words(net, chars, seq_batch_size))

    torch.save(net.state_dict(), 'trainedBook_v2.pt')
Example #2
0
    if args.optimizer == 'SGD_LR_SCHEDULE':
        lr_decay = lr_decay_base**max(epoch - m_flat_lr, 0)
        lr = lr * lr_decay  # decay lr if it is time

    # RUN MODEL ON TRAINING DATA
    train_ppl, train_loss = run_epoch(model, train_data, True, lr)

    # RUN MODEL ON VALIDATION DATA
    val_ppl, val_loss = run_epoch(model, valid_data)

    # SAVE MODEL IF IT'S THE BEST SO FAR
    if val_ppl < best_val_so_far:
        best_val_so_far = val_ppl
        if args.save_best:
            print("Saving model parameters to best_params.pt")
            torch.save(model.state_dict(),
                       os.path.join(args.save_dir, 'best_params.pt'))
        # NOTE ==============================================
        # You will need to load these parameters into the same model
        # for a couple Problems: so that you can compute the gradient
        # of the loss w.r.t. hidden state as required in Problem 5.2
        # and to sample from the model as required in Problem 5.3
        # We are not asking you to run on the test data, but if you
        # want to look at test performance you would load the saved
        # model and run on the test data with batch_size=1

    # LOG RESULTS
    train_ppls.append(train_ppl)
    val_ppls.append(val_ppl)
    train_losses.extend(train_loss)
    val_losses.extend(val_loss)
Example #3
0
        lr_decay = lr_decay_base ** max(epoch - m_flat_lr, 0)
        lr = lr * lr_decay # decay lr if it is time

    # RUN MODEL ON TRAINING DATA
    train_ppl, train_loss = run_epoch(model, train_data, True, lr)

    # RUN MODEL ON VALIDATION DATA
    val_ppl, val_loss = run_epoch(model, valid_data)


    # SAVE MODEL IF IT'S THE BEST SO FAR
    if val_ppl < best_val_so_far:
        best_val_so_far = val_ppl
        if args.save_best:
            print("Saving model parameters to best_params.pt")
            torch.save(model.state_dict(), os.path.join(args.save_dir, 'best_params.pt'))
        # NOTE ==============================================
        # You will need to load these parameters into the same model
        # for a couple Problems: so that you can compute the gradient 
        # of the loss w.r.t. hidden state as required in Problem 5.2
        # and to sample from the model as required in Problem 5.3
        # We are not asking you to run on the test data, but if you 
        # want to look at test performance you would load the saved
        # model and run on the test data with batch_size=1

    # LOG RESULTS
    train_ppls.append(train_ppl)
    val_ppls.append(val_ppl)
    train_losses.extend(train_loss)
    val_losses.extend(val_loss)
    times.append(time.time() - t0)
Example #4
0
# Epoch loop: one full pass over train_loader, a validation pass every
# args.val_step epochs (saving the weights when the mean validation loss
# improves on best_loss), then one scheduler step.
for epoch in range(args.num_epochs):
    # ---- training pass ----
    model.train()
    for i, data in enumerate(tqdm.tqdm(train_loader, desc='Train')):
        # Column layout note carried over from the original author:
        # reader readerat reader_f*8 reader_k*8 (item writer keywd*5 reg_ts maga)*N
        data = data[0].to(device)
        # first 18 columns go to the model as-is; the rest is viewed as (-1, 5, 9)
        items = data[:, 18:].contiguous().view(-1, 5, 9)
        item_logits = model(data[:, :18], items[:, :-1], mode=args.mode)
        # target is column 0 of the last of the 5 rows
        loss = criterion(item_logits[:, 0], items[:, -1, 0].long())

        model.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- periodic validation ----
    if not (epoch + 1) % args.val_step:
        model.eval()
        valid_loss = 0
        with torch.no_grad():
            for i, data in enumerate(tqdm.tqdm(valid_loader, desc='Valid')):
                data = data[0].to(device)
                items = data[:, 18:].contiguous().view(-1, 5, 9)
                item_preds = model(data[:, :18], items[:, :-1], mode=args.mode)
                batch_loss = criterion(item_preds[:, 0], items[:, -1, 0].long())
                loss = batch_loss.cpu().item()
                valid_loss += loss

        # mean over the number of validation batches (i is the last batch index)
        mean_valid_loss = valid_loss / (i + 1)
        print('epoch: ' + str(epoch + 1) + ' Loss: ' + str(mean_valid_loss))
        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            best_epoch = epoch + 1
            checkpoint_name = '%d_rnn_attention.pkl' % (epoch + 1)
            torch.save(model.state_dict(), args.save_path + checkpoint_name)
    scheduler.step()
Example #5
0
    valid_loss, valid_acc = funce(model, valid_iterator, criterion)
    if valid_acc < best_valid_acc:
        patience += 1
    else:
        patience = 0

    final_valid_loss = valid_loss
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        torch.save(
            {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': valid_loss,
            }, model_name)

    myprint(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    myprint(
        f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    myprint(
        f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

    train_acc_f.write(str(train_acc * 100) + '\n')
    valid_acc_f.write(str(valid_acc * 100) + '\n')
    train_acc_f.flush()
    valid_acc_f.flush()
            optimizer.zero_grad()

            input_, target = data[:, :-1], data[:, 1:]
            output, _ = model(input_)
            loss = criterion(output, target.reshape(-1))
            loss.backward()
            optimizer.step()
            loss_meter.add(loss.item())

            if (index + 1) % LOG_STEP == 0:
                print(
                    f"Epoch {epoch + 1} Batch {index + 1} Average loss {loss_meter.value()[0]}"
                )

        torch.save(
            model.state_dict(),
            op.join(LOG_PATH,
                    f"Epoch_{epoch}_AvgLoss_{loss_meter.value()[0]}.pth"))

# Interface Settings
START_WORD = "漂泊不见长安月"
USE_PREFIX = True
PREFIX_WORD = "落霞与孤鹜齐飞,秋水共长天一色。"

# Interface
if MODE == "interface":
    gen_word = list(START_WORD)
    start_word_length = len(gen_word)
    input = torch.Tensor([word2idx["<START>"]]).view(1, 1).long().to(device)
    hidden = None
    if USE_PREFIX:
Example #7
0
def train():
    """Train the RNN to predict the next VAE latent from (latent, action).

    The lexicographically last ``*k.pth.tar`` file under ``<hp.ckpt_dir>/vae``
    is loaded into the VAE, which is put in eval mode and only used without
    gradients for encoding.  If a matching checkpoint exists under
    ``<hp.ckpt_dir>/rnn`` the RNN weights are restored from it and
    ``global_step`` is recovered from the file name.

    Training runs until ``global_step`` reaches ``hp.max_step`` (checked once
    per pass over the loader).  Every ``hp.log_interval`` steps ``evaluate`` is
    called, a line is appended to ``train.log`` and three sample images are
    saved; every ``hp.save_interval`` steps the RNN and optimizer state are
    written to ``<step // 1000>k.pth.tar``.

    Raises:
        IndexError: if no VAE checkpoint is found.
    """
    global_step = 0

    # Load the pretrained VAE
    vae = VAE(hp.vsize).to(DEVICE)
    ckpt = sorted(glob.glob(os.path.join(hp.ckpt_dir, 'vae',
                                         '*k.pth.tar')))[-1]
    # map_location lets a checkpoint written on another device (e.g. a GPU
    # run resumed on a CPU-only machine) load onto DEVICE.
    vae_state = torch.load(ckpt, map_location=DEVICE)
    vae.load_state_dict(vae_state['model'])
    vae.eval()
    print('Loaded vae ckpt {}'.format(ckpt))

    rnn = RNN(hp.vsize, hp.asize, hp.rnn_hunits).to(DEVICE)
    ckpts = sorted(glob.glob(os.path.join(hp.ckpt_dir, 'rnn', '*k.pth.tar')))
    if ckpts:
        ckpt = ckpts[-1]
        rnn_state = torch.load(ckpt, map_location=DEVICE)
        rnn.load_state_dict(rnn_state['model'])
        # File names look like "012k.pth.tar": strip the trailing 'k' and
        # convert thousands of steps back to steps.
        global_step = int(os.path.basename(ckpt).split('.')[0][:-1]) * 1000
        print('Loaded rnn ckpt {}'.format(ckpt))

    data_path = hp.data_dir if not hp.extra else hp.extra_dir
    # NOTE(review): checkpoints store 'optimizer' state (see the save below)
    # but it is not restored here on resume -- confirm this is intended.
    optimizer = torch.optim.Adam(rnn.parameters(), lr=1e-4)
    dataset = GameEpisodeDataset(data_path, seq_len=hp.seq_len)
    loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=True,
                        drop_last=True,
                        num_workers=hp.n_workers,
                        collate_fn=collate_fn)
    testset = GameEpisodeDataset(data_path, seq_len=hp.seq_len, training=False)
    test_loader = DataLoader(testset,
                             batch_size=1,
                             shuffle=False,
                             drop_last=False,
                             collate_fn=collate_fn)

    ckpt_dir = os.path.join(hp.ckpt_dir, 'rnn')
    sample_dir = os.path.join(ckpt_dir, 'samples')
    os.makedirs(sample_dir, exist_ok=True)

    l1 = nn.L1Loss()

    while global_step < hp.max_step:
        with tqdm(enumerate(loader), total=len(loader), ncols=70,
                  leave=False) as t:
            t.set_description('Step {}'.format(global_step))
            for idx, (obs, actions) in t:
                obs, actions = obs.to(DEVICE), actions.to(DEVICE)
                # The VAE is frozen: encode without tracking gradients.
                with torch.no_grad():
                    latent_mu, latent_var = vae.encoder(obs)  # (B*T, vsize)
                    # Use the mean directly; the sampled alternative would be
                    # z = vae.reparam(latent_mu, latent_var)
                    z = latent_mu
                    z = z.view(-1, hp.seq_len, hp.vsize)  # (B*n_seq, T, vsize)

                # Targets are the latents shifted one step into the future.
                next_z = z[:, 1:, :]
                z, actions = z[:, :-1, :], actions[:, :-1, :]
                states = torch.cat([z, actions], dim=-1)  # (B, T, vsize+asize)
                x, _, _ = rnn(states)

                loss = l1(x, next_z)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                global_step += 1

                if global_step % hp.log_interval == 0:
                    eval_loss = evaluate(test_loader, vae, rnn, global_step)
                    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    with open(os.path.join(ckpt_dir, 'train.log'), 'a') as f:
                        log = '{} || Step: {}, train_loss: {:.4f}, loss: {:.4f}\n'.format(
                            now, global_step, loss.item(), eval_loss)
                        f.write(log)
                    S = 2
                    # Decoding is only for visual inspection, so do not build
                    # an autograd graph for it.
                    with torch.no_grad():
                        y = vae.decoder(x[S, :, :])
                        v = vae.decoder(next_z[S, :, :])
                    save_image(
                        y,
                        os.path.join(sample_dir,
                                     '{:04d}-rnn.png'.format(global_step)))
                    save_image(
                        v,
                        os.path.join(sample_dir,
                                     '{:04d}-vae.png'.format(global_step)))
                    # NOTE(review): x/next_z are indexed by sequence S, whereas
                    # obs is sliced by frame starting at row S -- confirm these
                    # are meant to show the same frames.
                    save_image(
                        obs[S:S + hp.seq_len - 1],
                        os.path.join(sample_dir,
                                     '{:04d}-obs.png'.format(global_step)))

                if global_step % hp.save_interval == 0:
                    d = {
                        'model': rnn.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }
                    torch.save(
                        d,
                        os.path.join(
                            ckpt_dir,
                            '{:03d}k.pth.tar'.format(global_step // 1000)))
Example #8
0
def main(argv):
    """Train the visit-sequence RNN classifier and report the best epoch.

    Parses ``argv`` into the module-level ``args``, loads the pickled
    train/valid/test sequences and labels from ``args.data_path``, trains with
    SGD for ``args.epochs`` epochs and steps a ``ReduceLROnPlateau`` scheduler
    on the validation loss.  Whenever the validation loss improves, the model
    is evaluated on the test set, the metrics are written to
    ``train_result.txt`` and the model is saved under ``args.save``.  A short
    summary is printed at the end.

    Args:
        argv: command-line arguments handed to ``parser.parse_args``.

    Raises:
        Exception: if ``args.cuda`` is set but no GPU is available.
    """
    global args
    args = parser.parse_args(argv)
    if args.threads == -1:
        args.threads = torch.multiprocessing.cpu_count() - 1 or 1
    print('===> Configuration')
    print(args)

    cuda = args.cuda
    if cuda:
        if torch.cuda.is_available():
            print('===> {} GPUs are available'.format(
                torch.cuda.device_count()))
        else:
            raise Exception("No GPU found, please run with --no-cuda")

    # Fix the random seed for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    def _load_pickle(name):
        """Unpickle ``args.data_path + name``."""
        # NOTE: unpickling can execute arbitrary code -- only point
        # ``data_path`` at files you trust.
        with open(args.data_path + name, 'rb') as f:
            return pickle.load(f)

    # Data loading
    print('===> Loading entire datasets')
    train_seqs = _load_pickle('train.seqs')
    train_labels = _load_pickle('train.labels')
    valid_seqs = _load_pickle('valid.seqs')
    valid_labels = _load_pickle('valid.labels')
    test_seqs = _load_pickle('test.seqs')
    test_labels = _load_pickle('test.labels')

    # Largest code over every visit of every patient in all three splits;
    # codes are used as indices, hence the +1.
    max_code = max(
        max(max(visit) for visit in patient)
        for patient in train_seqs + valid_seqs + test_seqs)
    num_features = max_code + 1

    print("     ===> Construct train set")
    train_set = VisitSequenceWithLabelDataset(train_seqs,
                                              train_labels,
                                              num_features,
                                              reverse=False)
    print("     ===> Construct validation set")
    valid_set = VisitSequenceWithLabelDataset(valid_seqs,
                                              valid_labels,
                                              num_features,
                                              reverse=False)
    print("     ===> Construct test set")
    test_set = VisitSequenceWithLabelDataset(test_seqs,
                                             test_labels,
                                             num_features,
                                             reverse=False)

    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=visit_collate_fn,
                              num_workers=args.threads)
    valid_loader = DataLoader(dataset=valid_set,
                              batch_size=args.eval_batch_size,
                              shuffle=False,
                              collate_fn=visit_collate_fn,
                              num_workers=args.threads)
    test_loader = DataLoader(dataset=test_set,
                             batch_size=args.eval_batch_size,
                             shuffle=False,
                             collate_fn=visit_collate_fn,
                             num_workers=args.threads)
    print('===> Dataset loaded!')

    # Create model
    print('===> Building a Model')

    model = RNN(dim_input=num_features, dim_emb=128, dim_hidden=128)

    if cuda:
        model = model.cuda()
    print(model)
    print('===> Model built!')

    # Class weights: class 0 is weighted by the mean of the training labels,
    # class 1 by its complement.
    weight_class0 = torch.mean(torch.FloatTensor(train_set.labels))
    weight_class1 = 1.0 - weight_class0
    weight = torch.FloatTensor([weight_class0, weight_class1])

    criterion = nn.CrossEntropyLoss(weight=weight)
    if cuda:
        criterion = criterion.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=False,
                                weight_decay=args.weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, 'min')

    best_valid_epoch = 0
    best_valid_loss = sys.float_info.max

    train_losses = []
    valid_losses = []

    # Pre-bind everything the final summary prints, so it cannot raise
    # UnboundLocalError when no epoch runs or no epoch ever improves on
    # ``best_valid_loss`` (e.g. the validation loss is NaN).
    train_loss = None
    test_loss = None
    test_auc = None
    test_aupr = None

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(args.save, exist_ok=True)

    for ei in trange(args.epochs, desc="Epochs"):
        # Train
        _, _, train_loss = rnn_epoch(train_loader,
                                     model,
                                     criterion=criterion,
                                     optimizer=optimizer,
                                     train=True)
        train_losses.append(train_loss)

        # Eval
        _, _, valid_loss = rnn_epoch(valid_loader, model, criterion=criterion)
        valid_losses.append(valid_loss)

        scheduler.step(valid_loss)

        is_best = valid_loss < best_valid_loss

        if is_best:
            best_valid_epoch = ei
            best_valid_loss = valid_loss

            # evaluate on the test set
            test_y_true, test_y_pred, test_loss = rnn_epoch(
                test_loader, model, criterion=criterion)

            if cuda:
                test_y_true = test_y_true.cpu()
                test_y_pred = test_y_pred.cpu()

            # column 1 of the predictions is used as the positive-class score
            test_auc = roc_auc_score(test_y_true.numpy(),
                                     test_y_pred.numpy()[:, 1],
                                     average="weighted")
            test_aupr = average_precision_score(test_y_true.numpy(),
                                                test_y_pred.numpy()[:, 1],
                                                average="weighted")

            with open(args.save + 'train_result.txt', 'w') as f:
                f.write('Best Validation Epoch: {}\n'.format(ei))
                f.write('Best Validation Loss: {}\n'.format(valid_loss))
                f.write('Train Loss: {}\n'.format(train_loss))
                f.write('Test Loss: {}\n'.format(test_loss))
                f.write('Test AUROC: {}\n'.format(test_auc))
                f.write('Test AUPR: {}\n'.format(test_aupr))

            torch.save(model, args.save + 'best_model.pth')
            torch.save(model.state_dict(), args.save + 'best_model_params.pth')

        # plot (redrawn and overwritten every epoch)
        if args.plot:
            plt.figure(figsize=(12, 9))
            plt.plot(np.arange(len(train_losses)),
                     np.array(train_losses),
                     label='Training Loss')
            plt.plot(np.arange(len(valid_losses)),
                     np.array(valid_losses),
                     label='Validation Loss')
            plt.xlabel('epoch')
            plt.ylabel('Loss')
            plt.legend(loc="best")
            plt.tight_layout()
            plt.savefig(args.save + 'loss_plot.eps', format='eps')
            plt.close()

    # Test metrics are from the best validation epoch; the train loss is the
    # one from the last epoch that ran.
    print('Best Validation Epoch: {}\n'.format(best_valid_epoch))
    print('Best Validation Loss: {}\n'.format(best_valid_loss))
    print('Train Loss: {}\n'.format(train_loss))
    print('Test Loss: {}\n'.format(test_loss))
    print('Test AUROC: {}\n'.format(test_auc))
    print('Test AUPR: {}\n'.format(test_aupr))
Example #9
0
def main(args):
    """Run the train/validate/test loop for the two-stream sequence model.

    A ``DataLayer`` and a shuffling ``DataLoader`` are built for every entry
    of ``args.phases``.  For each epoch, every phase whose name does not
    contain ``'Test'`` is run with gradients enabled and optimizer steps;
    ``'Test'`` phases run (without gradients) only on epochs listed in
    ``args.test_intervals``.  After each epoch a one-line summary is printed,
    and on test epochs the model ``state_dict`` is saved next to this file.

    Args:
        args: parsed options; the attributes read here are ``gpu``,
            ``data_root``, ``phases``, ``batch_size``, ``num_workers``,
            ``input_size``, ``hidden_size``, ``bidirectional``,
            ``num_classes``, ``lr``, ``start_epoch``, ``epochs``,
            ``test_intervals`` and ``debug``.
    """
    this_dir = osp.join(osp.dirname(__file__), '.')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # One dataset per phase, then one loader per dataset.
    datasets = {}
    for phase in args.phases:
        datasets[phase] = DataLayer(
            data_root=osp.join(args.data_root, phase),
            phase=phase,
        )

    data_loaders = {}
    for phase in args.phases:
        data_loaders[phase] = data.DataLoader(
            datasets[phase],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )

    model = Model(
        input_size=args.input_size,
        hidden_size=args.hidden_size,
        bidirectional=args.bidirectional,
        num_classes=args.num_classes,
    )
    model = model.apply(utl.weights_init).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    softmax = nn.Softmax(dim=1).to(device)
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr)

    first_epoch = args.start_epoch
    for epoch in range(first_epoch, first_epoch + args.epochs):
        # Per-phase accumulators: summed (sample-weighted) loss and hit count.
        losses = dict.fromkeys(args.phases, 0.0)
        corrects = dict.fromkeys(args.phases, 0.0)

        start = time.time()
        for phase in args.phases:
            training = 'Test' not in phase
            # Test phases are skipped outside the configured test epochs.
            if not training and epoch not in args.test_intervals:
                continue
            model.train(training)

            with torch.set_grad_enabled(training):
                for spatial, temporal, length, target in data_loaders[phase]:
                    # Reorder the batch by decreasing length, as
                    # pack_padded_sequence is called without
                    # enforce_sorted=False below.
                    order = utl.argsort(length)[::-1]
                    spatial_input = torch.zeros(*spatial.shape)
                    temporal_input = torch.zeros(*temporal.shape)
                    for row, src in enumerate(order):
                        spatial_input[row] = spatial[src]
                        temporal_input[row] = temporal[src]
                    target_input = [target[src] for src in order]
                    length_input = [length[src] for src in order]

                    spatial_input = spatial_input.to(device)
                    temporal_input = temporal_input.to(device)
                    target_input = torch.LongTensor(target_input).to(device)
                    pack1 = pack_padded_sequence(spatial_input,
                                                 length_input,
                                                 batch_first=True)
                    pack2 = pack_padded_sequence(temporal_input,
                                                 length_input,
                                                 batch_first=True)

                    score = model(pack1, pack2)
                    loss = criterion(score, target_input)
                    batch_loss = loss.item()
                    losses[phase] += batch_loss * target_input.shape[0]
                    if args.debug:
                        print(loss.item())

                    if not training:
                        # Count correct top-1 predictions on the CPU.
                        _, pred = torch.max(softmax(score), 1)
                        pred = pred.cpu()
                        hits = (pred == target_input.cpu()).sum().item()
                        corrects[phase] += hits
                        continue

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
        end = time.time()

        # Per-sample averages for the summary line.
        train_loss = losses['Train'] / len(data_loaders['Train'].dataset)
        val_loss = (losses['Validation'] /
                    len(data_loaders['Validation'].dataset))
        test_loss = losses['Test'] / len(data_loaders['Test'].dataset)
        test_acc = corrects['Test'] / len(data_loaders['Test'].dataset)
        print('Epoch {:2} | '
              'Train loss: {:.5f} Val loss: {:.5f} | '
              'Test loss: {:.5f} accuracy: {:.5f} | '
              'running time: {:.2f} sec'.format(
                  epoch,
                  train_loss,
                  val_loss,
                  test_loss,
                  test_acc,
                  end - start,
              ))

        if epoch in args.test_intervals:
            ckpt_name = './state_dict-epoch-' + str(epoch) + '.pth'
            torch.save(model.state_dict(), osp.join(this_dir, ckpt_name))
Example #10
0
        
        hidden_state2 = None
        output, hidden_state2 = rnn(test_inputs, hidden_state2)
        predict_price = output.data.numpy()
        pred_list.append(predict_price[0][0])

        # 加入新预测的价格,去掉顶部之前的价格
        test_dataset = np.concatenate((test_dataset, predict_price), axis=0)
        test_dataset = test_dataset[1:]
    
    # 预测的值是[0,1],将该值转换回原始值
    pred_list = np.reshape(pred_list, (-1, 1))
    pred_list = scaler.inverse_transform(pred_list)
    
    # 可视化预测未来几天的价格
    plt.plot(np.arange(len(train_y), len(train_y)+len(real_stock_price)), pred_list, 'y:')

    plt.legend(loc='best')
    plt.draw()
    plt.pause(0.05)

 # Do checkpointing
torch.save(rnn.state_dict(), '%s/time_series_rnn_model_params.pkl' % args.output)  # 只保存网络中的参数(速度快,占内存少)


plt.ioff()
plt.show()



Example #11
0
    train_ppl, train_loss = run_epoch(model, train_data, True, lr)
    # experiment.log_metric("train_loss", np.mean(train_loss), step=epoch)
    experiment.log_metric("train_perplexity", train_ppl, step=epoch)

    # RUN MODEL ON VALIDATION DATA
    val_ppl, val_loss = run_epoch(model, valid_data)
    # experiment.log_metric("val_loss", np.mean(val_loss), step=epoch)
    experiment.log_metric("val_perplexity", val_ppl, step=epoch)

    # SAVE MODEL IF IT'S THE BEST SO FAR
    if val_ppl < best_val_so_far:
        best_val_so_far = val_ppl
        if args.save_best:
            print("Saving model parameters to best_params.pt")
            best_model_path = os.path.join(args.save_dir, 'best_params.pt')
            torch.save(model.state_dict(), best_model_path)
            experiment.log_asset(best_model_path, overwrite=True)
        # NOTE ==============================================
        # You will need to load these parameters into the same model
        # for a couple Problems: so that you can compute the gradient
        # of the loss w.r.t. hidden state as required in Problem 5.2
        # and to sample from the model as required in Problem 5.3
        # We are not asking you to run on the test data, but if you
        # want to look at test performance you would load the saved
        # model and run on the test data with batch_size=1

    # LOG RESULTS
    train_ppls.append(train_ppl)
    val_ppls.append(val_ppl)
    train_losses.extend(train_loss)
    val_losses.extend(val_loss)
Example #12
0
    train_compressed_signal, _, acc_train = eval_RNN_Model(
        train_loader, time_step, input_size, model, num_classes, criterion,
        "train", path)
    test_compressed_signal, loss_test, acc_test = eval_RNN_Model(
        test_loader, time_step, input_size, model, num_classes, criterion,
        "test", path)

    # anneal learning
    scheduler.step(loss_test)

    if acc_train > best_acc_train:
        save_path = os.path.join(path, 'compressed_train_GRU')
        np.save(save_path, train_compressed_signal)

    if acc_test > best_acc_test:
        save_path = os.path.join(path, 'compressed_test_GRU')
        np.save(save_path, test_compressed_signal)

    is_best = acc_test > best_acc_test
    best_acc_train = max(acc_train, best_acc_train)
    best_acc_test = max(acc_test, best_acc_test)
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'arch': arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc_test,
            'optimizer': op.state_dict(),
        }, is_best)
Example #13
0
class Trainer:
    """Training driver for the CNN-encoder / RNN-decoder captioning model."""

    def __init__(self, _hparams):
        """Build loaders, models, loss, optimizer and the summary writer.

        :param _hparams: hyper-parameter object; the attributes read here are
            ``fixed_seed``, ``fea_dim``, ``embed_dim``, ``hid_dim``,
            ``max_sen_len``, ``vocab_pkl``, ``lr``, ``val_cap``,
            ``ft_encoder_lr`` and ``ft_decoder_lr``.
        """
        utils.set_seed(_hparams.fixed_seed)

        self.train_loader = get_train_loader(_hparams)
        self.val_loader = get_val_loader(_hparams)
        self.encoder = CNN().to(DEVICE)
        self.decoder = RNN(fea_dim=_hparams.fea_dim,
                           embed_dim=_hparams.embed_dim,
                           hid_dim=_hparams.hid_dim,
                           max_sen_len=_hparams.max_sen_len,
                           vocab_pkl=_hparams.vocab_pkl).to(DEVICE)
        self.loss_fn = nn.CrossEntropyLoss()
        # Only the parameters returned by get_params() are optimized here.
        self.optimizer = torch.optim.Adam(self.get_params(), lr=_hparams.lr)
        self.writer = SummaryWriter()

        self.max_sen_len = _hparams.max_sen_len
        self.val_cap = _hparams.val_cap
        self.ft_encoder_lr = _hparams.ft_encoder_lr
        self.ft_decoder_lr = _hparams.ft_decoder_lr
        self.best_CIDEr = 0

    def fine_tune_encoder(self, fine_tune_epochs, val_interval, save_path,
                          val_path):
        """Train encoder and decoder jointly, then freeze the encoder again.

        Replaces ``self.optimizer`` with an Adam optimizer that has one
        parameter group per model (``ft_encoder_lr`` / ``ft_decoder_lr``) and
        runs :meth:`training` with it.

        :param fine_tune_epochs: number of epochs to fine-tune for
        :param val_interval: validate every this many epochs
        :param save_path: where the best model is saved
        :param val_path: where the sentences generated during validation are saved
        """
        print('*' * 20, 'fine tune encoder for', fine_tune_epochs, 'epochs',
              '*' * 20)
        self.encoder.fine_tune()
        self.optimizer = torch.optim.Adam([
            {
                'params': self.encoder.parameters(),
                'lr': self.ft_encoder_lr
            },
            {
                'params': self.decoder.parameters(),
                'lr': self.ft_decoder_lr
            },
        ])
        self.training(fine_tune_epochs, val_interval, save_path, val_path)
        self.encoder.froze()
        print('*' * 20, 'fine tune encoder complete', '*' * 20)

    def get_params(self):
        """Return every parameter the optimizer should update.

        The encoder is, for now, designed not to be trained, so only the
        decoder's parameters are included.

        :return: list of the decoder's parameters
        """
        return list(self.decoder.parameters())

    def training(self, max_epochs, val_interval, save_path, val_path):
        """Run the training loop.

        :param max_epochs: maximum number of training epochs
        :param val_interval: validate every this many epochs
        :param save_path: where the model is saved
        :param val_path: where the sentences generated during validation are saved
        :return:
        """
        print('*' * 20, 'train', '*' * 20)
        for epoch in range(max_epochs):
            self.set_train()

            epoch_loss = 0
            epoch_steps = len(self.train_loader)
            # Wrap the loader itself so tqdm can read its length and show
            # progress against a known total.
            for img, cap, cap_len in tqdm(self.train_loader):
                # batch_size * 3 * 224 * 224
                img = img.to(DEVICE)
                cap = cap.to(DEVICE)

                self.optimizer.zero_grad()

                # Call the modules (not ``.forward``) so registered hooks run.
                features = self.encoder(img)
                outputs = self.decoder(features, cap)

                # Pack predictions and the one-step-shifted captions with the
                # same lengths so padded positions are left out of the loss.
                outputs = pack_padded_sequence(outputs,
                                               cap_len - 1,
                                               batch_first=True)[0]
                targets = pack_padded_sequence(cap[:, 1:],
                                               cap_len - 1,
                                               batch_first=True)[0]
                train_loss = self.loss_fn(outputs, targets)
                epoch_loss += train_loss.item()
                train_loss.backward()
                self.optimizer.step()

            epoch_loss /= epoch_steps
            self.writer.add_scalar('epoch_loss', epoch_loss, epoch)
            print('epoch_loss: {}, epoch: {}'.format(epoch_loss, epoch))
            if (epoch + 1) % val_interval == 0:
                CIDEr = self.validating(epoch, val_path)
                # ``<=`` means a tie with the best score also overwrites the
                # saved model.
                if self.best_CIDEr <= CIDEr:
                    self.best_CIDEr = CIDEr
                    self.save_model(save_path, epoch)

    def save_model(self, save_path, train_epoch):
        """Save the best model.

        :param save_path: path of the model file to write
        :param train_epoch: current training epoch
        :return:
        """
        model_state_dict = {
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            # Key spelling kept as-is: existing checkpoints use it.
            'tran_epoch': train_epoch,
        }
        print('*' * 20, 'save model to: ', save_path, '*' * 20)
        torch.save(model_state_dict, save_path)

    def validating(self, train_epoch, val_path):
        """Generate captions for the validation set and score them.

        The generated sentences are dumped as JSON to ``val_path``, scored
        with ``coco_eval`` against ``self.val_cap``, and every metric is
        logged to the summary writer.

        :param train_epoch: current training epoch (used as the logging step)
        :param val_path: where the generated sentences are saved
        :return: the ``'CIDEr'`` score
        """
        print('*' * 20, 'validate', '*' * 20)
        self.set_eval()
        sen_json = []
        with torch.no_grad():
            for img, img_id in tqdm(self.val_loader):
                img = img.to(DEVICE)
                features = self.encoder(img)
                sens, _ = self.decoder.sample(features)
                # Only the first sentence is kept and ``int(img_id)`` is
                # taken -- presumably the loader yields one image per batch;
                # TODO confirm.
                sen_json.append({'image_id': int(img_id), 'caption': sens[0]})

        with open(val_path, 'w') as f:
            json.dump(sen_json, f)

        result = coco_eval(self.val_cap, val_path)
        scores = {}
        for metric, score in result:
            scores[metric] = score
            self.writer.add_scalar(metric, score, train_epoch)

        return scores['CIDEr']

    def set_train(self):
        """Put encoder and decoder into training mode."""
        self.encoder.train()
        self.decoder.train()

    def set_eval(self):
        """Put encoder and decoder into evaluation mode."""
        self.encoder.eval()
        self.decoder.eval()