示例#1
0
def train_model(no_epochs):
    """Train an ``Action_Conditioned_FF`` model, keeping the best checkpoint.

    After each epoch the model is evaluated on ``data_loaders.test_loader``.
    Whenever that loss improves on the best value seen so far, the weights
    are written to ``saved/saved_model.pkl``. The checkpoint on disk is
    therefore the best-scoring epoch, not necessarily the last one.

    Args:
        no_epochs: Number of passes over ``data_loaders.train_loader``.
    """
    batch_size = 16
    checkpoint_path = "saved/saved_model.pkl"

    data_loaders = Data_Loaders(batch_size)
    model = Action_Conditioned_FF()
    loss_function = torch.nn.BCEWithLogitsLoss()

    # Test-loss history; the first entry is the loss of the untrained model.
    losses = [model.evaluate(model, data_loaders.test_loader, loss_function)]

    optimizer = Adam(model.parameters(), lr=0.001)
    best_loss = float('inf')
    checkpoint_written = False
    for _ in tqdm(range(no_epochs)):
        model.train()
        for sample in data_loaders.train_loader:
            x, y = sample['input'], sample['label']
            optimizer.zero_grad()
            # Call the module rather than .forward() so registered hooks run.
            y_hat = model(x)
            # NOTE(review): reshape(1) only works when the output holds
            # exactly one element -- confirm the loader yields single samples
            # despite batch_size = 16.
            loss = loss_function(y_hat.reshape(1).float(), y.float())
            loss.backward()
            optimizer.step()

        total_loss = model.evaluate(model, data_loaders.test_loader,
                                    loss_function)
        if total_loss < best_loss:
            torch.save(model.state_dict(),
                       checkpoint_path,
                       _use_new_zipfile_serialization=False)
            best_loss = total_loss
            checkpoint_written = True
        print()
        print(total_loss)
        # Reuse the value computed above instead of a second test-set pass.
        losses.append(total_loss)

    # Do not overwrite the best checkpoint with the last-epoch weights.
    # Save here only if no epoch produced one (no epochs, or the loss never
    # compared below infinity, e.g. NaN), so the file is always written.
    if not checkpoint_written:
        torch.save(model.state_dict(),
                   checkpoint_path,
                   _use_new_zipfile_serialization=False)
示例#2
0
def train_model(no_epochs):
    """Train an ``Action_Conditioned_FF`` model with BCE loss and save it.

    Each epoch first evaluates the model on ``data_loaders.test_loader`` and
    prints that loss, then runs one pass over ``data_loaders.train_loader``
    and prints the mean training loss. After the last epoch the model's
    ``state_dict`` is written to ``saved_model.pkl``.

    Args:
        no_epochs: Number of passes over the training loader.
    """
    batch_size = 32
    data_loaders = dl.Data_Loaders(batch_size)
    model = Action_Conditioned_FF()
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.model.parameters(), lr=0.001)

    test_losses = []
    train_losses = []

    for _ in range(no_epochs):
        loss_test = model.evaluate(model, data_loaders.test_loader, loss_fn)
        print('test')
        print(loss_test)
        test_losses.append(loss_test)

        # Enter training mode after evaluating: if evaluate() switches the
        # module to eval mode (not visible here -- TODO confirm), calling
        # train() before it would leave the training pass in eval mode.
        model.train()

        batch_losses = []
        for sample in data_loaders.train_loader:
            features = sample['input'].float()
            label = sample['label'].float()

            # Forward step; call the module (not .forward) so hooks run.
            out = model(features)
            loss_train = loss_fn(out, label.view(-1, 1))
            # Keep a plain float: storing the tensor would retain every
            # batch's autograd graph until the end of the epoch.
            batch_losses.append(loss_train.item())

            # Backpropagation
            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()

        # An empty training loader yields NaN instead of ZeroDivisionError.
        if batch_losses:
            mean_train_loss = sum(batch_losses) / len(batch_losses)
        else:
            mean_train_loss = float('nan')
        print('train')
        print(mean_train_loss)
        train_losses.append(mean_train_loss)

    torch.save(model.state_dict(), 'saved_model.pkl', _use_new_zipfile_serialization=False)