Example 1
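This example trains a Neural Turing Machine (NTM) on the copy task: the model reads a random binary sequence and must then reproduce it from memory alone.
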
import os
from datetime import datetime

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# NTM and get_training_sequence are assumed to come from the surrounding
# project; args, seed, and model_path are assumed to be module-level globals
# (args parsed from the command line, e.g. with argparse).


def train(epochs=50_000):
    tensorboard_log_folder = f"runs/copy-task-{datetime.now().strftime('%Y-%m-%dT%H%M%S')}"
    writer = SummaryWriter(tensorboard_log_folder)
    print(f"Training for {epochs} epochs, logging in {tensorboard_log_folder}")
    sequence_min_length = 1
    sequence_max_length = 20
    vector_length = 8
    memory_size = (128, 20)  # 128 memory locations, each a 20-dimensional vector
    hidden_layer_size = 100
    batch_size = 4
    lstm_controller = not args.ff  # the --ff flag presumably selects a feed-forward controller

    writer.add_scalar("sequence_min_length", sequence_min_length)
    writer.add_scalar("sequence_max_length", sequence_max_length)
    writer.add_scalar("vector_length", vector_length)
    writer.add_scalar("memory_size0", memory_size[0])
    writer.add_scalar("memory_size1", memory_size[1])
    writer.add_scalar("hidden_layer_size", hidden_layer_size)
    writer.add_scalar("lstm_controller", lstm_controller)
    writer.add_scalar("seed", seed)
    writer.add_scalar("batch_size", batch_size)

    model = NTM(vector_length, hidden_layer_size, memory_size, lstm_controller)

    # RMSprop with momentum, a standard optimizer choice for NTM training.
    optimizer = optim.RMSprop(model.parameters(),
                              momentum=0.9,
                              alpha=0.95,
                              lr=1e-4)
    feedback_frequency = 100
    total_loss = []
    total_cost = []

    os.makedirs("models", exist_ok=True)
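    # Resume from an existing checkpoint if one is present.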
    if os.path.isfile(model_path):
        print(f"Loading model from {model_path}")
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint)

    # One training step per epoch, each on a freshly sampled batch;
    # range(epochs + 1) makes the final step land on a logging epoch.
    for epoch in range(epochs + 1):
        optimizer.zero_grad()
        inputs, target = get_training_sequence(sequence_min_length,
                                               sequence_max_length,
                                               vector_length, batch_size)
        state = model.get_initial_state(batch_size)
        # Read phase: present the whole input sequence to the model.
        for vector in inputs:
            _, state = model(vector, state)
        # Output phase: feed zero vectors while the model reproduces the
        # sequence from its memory.
        y_out = torch.zeros(target.size())
        for j in range(len(target)):
            y_out[j], state = model(torch.zeros(batch_size, vector_length + 1),
                                    state)
        loss = F.binary_cross_entropy(y_out, target)
        loss.backward()
        optimizer.step()
        total_loss.append(loss.item())
        # Binarize the output and count bit errors, averaged over time steps.
        y_out_binarized = (y_out.detach() >= 0.5).float()
        cost = torch.sum(torch.abs(y_out_binarized - target)) / len(target)
        total_cost.append(cost.item())
        # Periodically report the running averages and reset the accumulators.
        if epoch % feedback_frequency == 0:
            running_loss = sum(total_loss) / len(total_loss)
            running_cost = sum(total_cost) / len(total_cost)
            print(f"Loss at step {epoch}: {running_loss}")
            writer.add_scalar('training loss', running_loss, epoch)
            writer.add_scalar('training cost', running_cost, epoch)
            total_loss = []
            total_cost = []

    # Persist the trained weights and flush any pending TensorBoard events.
    torch.save(model.state_dict(), model_path)
    writer.close()
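
The helper get_training_sequence is not shown here. Below is a minimal sketch of what it might look like, following the standard NTM copy-task format (random binary vectors followed by a delimiter bit on a separate channel); the project's actual helper may differ in details such as how the sequence length is sampled:

def get_training_sequence(min_length, max_length, vector_length, batch_size):
    # Sample a sequence length for this batch.
    length = torch.randint(min_length, max_length + 1, (1,)).item()
    # Random binary vectors, shape (length, batch_size, vector_length).
    sequence = torch.bernoulli(torch.full((length, batch_size, vector_length), 0.5))
    # Inputs carry one extra channel, used only by the end-of-sequence delimiter.
    inputs = torch.zeros(length + 1, batch_size, vector_length + 1)
    inputs[:length, :, :vector_length] = sequence
    inputs[length, :, vector_length] = 1.0  # delimiter marks the end of the input
    # The target is the sequence itself: the model must reproduce it from memory.
    return inputs, sequence

Training progress can be followed live with TensorBoard, e.g. tensorboard --logdir runs.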