Example #1
def handler(event, context):
    start_time = time.time()
    bucket = event['data_bucket']
    worker_index = event['rank']
    num_workers = event['num_workers']
    key = 'training_{}.pt'.format(worker_index)

    print('data bucket = {}'.format(bucket))
    print('worker index = {}'.format(worker_index))
    print('number of workers = {}'.format(num_workers))
    print('key = {}'.format(key))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(constants.HOST, constants.PORT)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}"
          .format(constants.HOST, constants.PORT))

    # read file from s3
    readS3_start = time.time()
    s3.Bucket(bucket).download_file(key, os.path.join(local_dir, training_file))
    s3.Bucket(bucket).download_file(test_file, os.path.join(local_dir, test_file))
    print("read data cost {} s".format(time.time() - readS3_start))


    # preprocess dataset
    preprocess_start = time.time()

    trainset = torch.load(os.path.join(local_dir, training_file))
    testset = torch.load(os.path.join(local_dir, test_file))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    print("preprocess data cost {} s".format(time.time() - preprocess_start))

    device = 'cpu'
    torch.manual_seed(1234)

    # Model
    print('==> Building model..')
    # net = VGG('VGG19')
    # net = ResNet18()
    # net = ResNet50()
    # net = PreActResNet18()
    # net = GoogLeNet()
    # net = DenseNet121()
    # net = ResNeXt29_2x64d()
    net = MobileNet()
    # net = MobileNetV2()
    # net = DPN92()
    # net = ShuffleNetG2()
    # net = SENet18()
    # net = ShuffleNetV2(1)
    # net = EfficientNetB0()

    net = net.to(device)

    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

    # register model
    model_name = "dnn"
    weight = [param.data.numpy() for param in net.parameters()]
    weight_shape = [layer.shape for layer in weight]
    weight_size = [layer.size for layer in weight]
    weight_length = sum(weight_size)

    ps_client.register_model(t_client, worker_index, model_name, weight_length, num_workers)
    ps_client.exist_model(t_client, model_name)

    print("register and check model >>> name = {}, length = {}".format(model_name, weight_length))

    # Training the Model
    train_start = time.time()
    iter_counter = 0

    for epoch in range(NUM_EPOCHS):

        epoch_start = time.time()

        net.train()
        num_batch = 0
        train_acc = Accuracy()
        train_loss = Average()

        for batch_idx, (inputs, targets) in enumerate(trainloader):

            # print("------worker {} epoch {} batch {}------".format(worker_index, epoch+1, batch_idx+1))
            batch_start = time.time()
            
            # pull latest model
            pull_start = time.time()
            ps_client.can_pull(t_client, model_name, iter_counter, worker_index)
            latest_model = ps_client.pull_model(t_client, model_name, iter_counter, worker_index)
            latest_model = np.asarray(latest_model, dtype=np.float32)
            pull_time = time.time() - pull_start

            # update the model
            offset = 0
            for layer_index, param in enumerate(net.parameters()):

                layer_value = latest_model[offset : offset + weight_size[layer_index]].reshape(weight_shape[layer_index])
                param.data = torch.from_numpy(layer_value)

                offset += weight_size[layer_index]
    
            # Forward + Backward + Optimize
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = F.cross_entropy(outputs, targets)
            
            optimizer.zero_grad()
            loss.backward()
            
            train_acc.update(outputs, targets)
            train_loss.update(loss.item(), inputs.size(0))
            
            # flatten and concat gradients of weight and bias
            for index, param in enumerate(net.parameters()):
                if index == 0:
                    flattened_grad = param.grad.data.numpy().flatten()
                else:
                    flattened_grad = np.concatenate((flattened_grad, param.grad.data.numpy().flatten()))
    
            flattened_grad = flattened_grad * -1
            
            # push gradient to PS
            push_start = time.time()
            ps_client.can_push(t_client, model_name, iter_counter, worker_index)
            ps_client.push_grad(t_client, model_name, flattened_grad, learning_rate, iter_counter, worker_index)
            ps_client.can_pull(t_client, model_name, iter_counter+1, worker_index)      # sync all workers
            push_time = time.time() - push_start

            iter_counter += 1
            num_batch += 1
            
            step_time = time.time() - batch_start
            
            print("Epoch:[{}/{}], Step:[{}/{}];\n Training Loss:{}, Training accuracy:{};\n Step Time:{}, Calculation Time:{}, Communication Time:{}".format(
                epoch, NUM_EPOCHS, num_batch, len(trainloader), train_loss, train_acc, step_time, step_time - (pull_time + push_time), pull_time + push_time))

        # Test the Model
        net.eval()
        test_loss = Average()
        test_acc = Accuracy()
        
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                
                inputs, targets = inputs.to(device), targets.to(device)
                
                outputs = net(inputs)
                
                loss = F.cross_entropy(outputs, targets)
                
                test_loss.update(loss.item(), inputs.size(0))
                test_acc.update(outputs, targets)

        print('Time = {:.4f}, accuracy of the model on test set: {}, loss = {}'
              .format(time.time() - train_start, test_acc, test_loss))

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
Example #2
def handler(event, context):
    start_time = time.time()
    bucket = event['bucket_name']
    worker_index = event['rank']
    num_workers = event['num_workers']
    key = event['file']
    host = event['host']
    port = event['port']

    print('bucket = {}'.format(bucket))
    print('number of workers = {}'.format(num_workers))
    print('worker index = {}'.format(worker_index))
    print("file = {}".format(key))
    print("host = {}".format(host))
    print("port = {}".format(port))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(host, port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
        host, port))

    # read file from s3
    file = get_object(bucket, key).read().decode('utf-8').split("\n")
    print("read data cost {} s".format(time.time() - start_time))

    # parse dataset
    parse_start = time.time()
    dataset = DenseDatasetWithLines(file, NUM_FEATURES)
    print("parse data cost {} s".format(time.time() - parse_start))

    # preprocess dataset
    preprocess_start = time.time()
    dataset_size = len(dataset)
    # indices for training and validation splits
    indices = list(range(dataset_size))
    split = int(np.floor(VALIDATION_RATIO * dataset_size))
    if SHUFFLE_DATASET:
        np.random.seed(RANDOM_SEED)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    # Creating PT data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=BATCH_SIZE,
                                               sampler=train_sampler)
    validation_loader = torch.utils.data.DataLoader(dataset,
                                                    batch_size=BATCH_SIZE,
                                                    sampler=valid_sampler)
    print("preprocess data cost {} s".format(time.time() - preprocess_start))

    model = LogisticRegression(NUM_FEATURES, NUM_CLASSES)

    # Loss and Optimizer
    # Softmax is internally computed.
    # Set parameters to be updated.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

    # register model
    model_name = "w.b"
    weight_shape = model.linear.weight.data.numpy().shape
    weight_length = weight_shape[0] * weight_shape[1]
    bias_shape = model.linear.bias.data.numpy().shape
    bias_length = bias_shape[0]
    model_length = weight_length + bias_length
    ps_client.register_model(t_client, worker_index, model_name, model_length,
                             num_workers)
    ps_client.exist_model(t_client, model_name)
    print("register and check model >>> name = {}, length = {}".format(
        model_name, model_length))

    # Training the Model
    train_start = time.time()
    iter_counter = 0
    for epoch in range(NUM_EPOCHS):
        epoch_start = time.time()
        for batch_index, (items, labels) in enumerate(train_loader):
            print("------worker {} epoch {} batch {}------".format(
                worker_index, epoch, batch_index))
            batch_start = time.time()

            # pull latest model
            ps_client.can_pull(t_client, model_name, iter_counter,
                               worker_index)
            latest_model = ps_client.pull_model(t_client, model_name,
                                                iter_counter, worker_index)
            model.linear.weight = Parameter(
                torch.from_numpy(
                    np.asarray(latest_model[:weight_length],
                               dtype=np.double).reshape(weight_shape)))
            model.linear.bias = Parameter(
                torch.from_numpy(
                    np.asarray(latest_model[weight_length:],
                               dtype=np.double).reshape(bias_shape[0])))

            items = Variable(items.view(-1, NUM_FEATURES))
            labels = Variable(labels)

            # Forward + Backward + Optimize
            optimizer.zero_grad()
            outputs = model(items.double())
            loss = criterion(outputs, labels)
            loss.backward()

            # flatten and concat gradients of weight and bias
            w_b_grad = np.concatenate(
                (model.linear.weight.grad.data.numpy().flatten(),
                 model.linear.bias.grad.data.numpy().flatten()))
            cal_time = time.time() - batch_start

            # push gradient to PS
            sync_start = time.time()
            ps_client.can_push(t_client, model_name, iter_counter,
                               worker_index)
            ps_client.push_grad(t_client, model_name, w_b_grad, LEARNING_RATE,
                                iter_counter, worker_index)
            ps_client.can_pull(t_client, model_name, iter_counter + 1,
                               worker_index)  # sync all workers
            sync_time = time.time() - sync_start

            print(
                'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                % (epoch + 1, NUM_EPOCHS, batch_index + 1,
                   len(train_indices) / BATCH_SIZE, time.time() - train_start,
                   loss.item(), time.time() - epoch_start,
                   time.time() - batch_start, cal_time, sync_time))
            iter_counter += 1

        # Test the Model
        correct = 0
        total = 0
        test_loss = 0
        for items, labels in validation_loader:
            items = Variable(items.view(-1, NUM_FEATURES))
            labels = Variable(labels)
            outputs = model(items)
            test_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()

        print(
            'Time = %.4f, accuracy of the model on the %d test samples: %d %%, loss = %f'
            % (time.time() - train_start, len(val_indices),
               100 * correct / total, test_loss))

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
Example #3
def handler(event, context):
    # dataset
    data_bucket = event['data_bucket']
    file = event['file']
    dataset_type = event["dataset_type"]
    assert dataset_type == "sparse_libsvm"
    n_features = event['n_features']

    # ps setting
    host = event['host']
    port = event['port']

    # hyper-parameter
    n_clusters = event['n_clusters']
    n_epochs = event["n_epochs"]
    threshold = event["threshold"]
    sync_mode = event["sync_mode"]
    n_workers = event["n_workers"]
    worker_index = event['worker_index']
    assert sync_mode.lower() == Synchronization.Reduce

    print('data bucket = {}'.format(data_bucket))
    print("file = {}".format(file))
    print('number of workers = {}'.format(n_workers))
    print('worker index = {}'.format(worker_index))
    print('num clusters = {}'.format(n_clusters))
    print('host = {}'.format(host))
    print('port = {}'.format(port))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(host, port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()
    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(host, port))

    # Reading data from S3
    read_start = time.time()
    storage = S3Storage()
    lines = storage.load(file, data_bucket).read().decode('utf-8').split("\n")
    print("read data cost {} s".format(time.time() - read_start))

    parse_start = time.time()
    dataset = libsvm_dataset.from_lines(lines, n_features, dataset_type)
    train_set = dataset.ins_list
    np_dtype = train_set[0].to_dense().numpy().dtype
    centroid_shape = (n_clusters, n_features)
    print("parse data cost {} s".format(time.time() - parse_start))
    print("dataset type: {}, data type: {}, centroids shape: {}"
          .format(dataset_type, np_dtype, centroid_shape))

    # register model
    model_name = Prefix.KMeans_Cent
    model_length = centroid_shape[0] * centroid_shape[1] + 1
    ps_client.register_model(t_client, worker_index, model_name, model_length, n_workers)
    ps_client.exist_model(t_client, model_name)
    print("register and check model >>> name = {}, length = {}".format(model_name, model_length))

    init_centroids_start = time.time()
    ps_client.can_pull(t_client, model_name, 0, worker_index)
    ps_model = ps_client.pull_model(t_client, model_name, 0, worker_index)
    if worker_index == 0:
        centroids_np = sparse_centroid_to_numpy(train_set[0:n_clusters], n_clusters)
        ps_client.can_push(t_client, model_name, 0, worker_index)
        ps_client.push_grad(t_client, model_name,
                            np.append(centroids_np.flatten(), 1000.).astype(np.double) - np.asarray(ps_model).astype(np.double),
                            1., 0, worker_index)
    else:
        centroids_np = np.zeros(centroid_shape)
        ps_client.can_push(t_client, model_name, 0, worker_index)
        ps_client.push_grad(t_client, model_name,
                            np.append(centroids_np.flatten(), 0).astype(np.double),
                            0, 0, worker_index)
    ps_client.can_pull(t_client, model_name, 1, worker_index)
    ps_model = ps_client.pull_model(t_client, model_name, 1, worker_index)
    cur_centroids = np.array(ps_model[0:-1]).astype(np.float32).reshape(centroid_shape)
    cur_error = float(ps_model[-1])
    print("initial centroids cost {} s".format(time.time() - init_centroids_start))

    model = cluster_models.get_model(train_set, torch.from_numpy(cur_centroids), dataset_type,
                                     n_features, n_clusters)

    train_start = time.time()
    for epoch in range(1, n_epochs + 1):
        epoch_start = time.time()

        # local computation
        model.find_nearest_cluster()
        local_cent = model.get_centroids("numpy").reshape(-1)
        local_cent_error = np.concatenate((local_cent.astype(np.double).flatten(),
                                           np.array([model.error], dtype=np.double)))
        epoch_cal_time = time.time() - epoch_start

        # push updates
        epoch_comm_start = time.time()
        last_cent_error = np.concatenate((cur_centroids.astype(np.double).flatten(),
                                          np.array([cur_error], dtype=np.double)))
        ps_model_inc = local_cent_error - last_cent_error
        ps_client.can_push(t_client, model_name, epoch, worker_index)
        ps_client.push_grad(t_client, model_name,
                            ps_model_inc, 1. / n_workers, epoch, worker_index)

        # pull new model
        ps_client.can_pull(t_client, model_name, epoch + 1, worker_index)   # sync all workers
        ps_model = ps_client.pull_model(t_client, model_name, epoch + 1, worker_index)
        model.centroids = [torch.from_numpy(c).reshape(1, n_features).to_sparse()
                           for c in np.array(ps_model[0:-1]).astype(np.float32).reshape(centroid_shape)]

        model.error = float(ps_model[-1])
        cur_centroids = model.get_centroids("numpy")
        cur_error = model.error

        epoch_comm_time = time.time() - epoch_comm_start

        print("Epoch[{}] Worker[{}], error = {}, cost {} s, cal cost {} s, sync cost {} s"
              .format(epoch, worker_index, model.error,
                      time.time() - epoch_start, epoch_cal_time, epoch_comm_time))

        if model.error < threshold:
            break

    print("Worker[{}] finishes training: Error = {}, cost {} s"
          .format(worker_index, model.error, time.time() - train_start))
    return
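
The k-means model registered with the parameter server is one flat vector of length n_clusters * n_features + 1: the flattened centroids followed by a single error scalar. A minimal sketch of that encoding (the helper names are made up for illustration):

import numpy as np


def pack_centroids(centroids, error):
    # layout: [c_0_0, ..., c_{k-1}_{d-1}, error]
    return np.append(centroids.astype(np.double).flatten(), np.double(error))


def unpack_centroids(ps_model, n_clusters, n_features):
    arr = np.asarray(ps_model, dtype=np.double)
    return arr[:-1].reshape(n_clusters, n_features), float(arr[-1])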
Example #4
def handler(event, context):
    start_time = time.time()

    # dataset setting
    train_file = event['train_file']
    test_file = event['test_file']
    data_bucket = event['data_bucket']
    n_features = event['n_features']
    n_classes = event['n_classes']
    n_workers = event['n_workers']
    worker_index = event['worker_index']
    cp_bucket = event['cp_bucket']

    # ps setting
    host = event['host']
    port = event['port']

    # training setting
    model_name = event['model']
    optim = event['optim']
    sync_mode = event['sync_mode']
    assert model_name.lower() in MLModel.Deep_Models
    assert optim.lower() in Optimization.Grad_Avg
    assert sync_mode.lower() in Synchronization.Reduce

    # hyper-parameter
    learning_rate = event['lr']
    batch_size = event['batch_size']
    n_epochs = event['n_epochs']
    start_epoch = event['start_epoch']
    run_epochs = event['run_epochs']

    function_name = event['function_name']

    print('data bucket = {}'.format(data_bucket))
    print("train file = {}".format(train_file))
    print("test file = {}".format(test_file))
    print('number of workers = {}'.format(n_workers))
    print('worker index = {}'.format(worker_index))
    print('model = {}'.format(model_name))
    print('optimization = {}'.format(optim))
    print('sync mode = {}'.format(sync_mode))
    print('start epoch = {}'.format(start_epoch))
    print('run epochs = {}'.format(run_epochs))
    print('host = {}'.format(host))
    print('port = {}'.format(port))

    print("Run function {}, round: {}/{}, epoch: {}/{} to {}/{}".format(
        function_name,
        int(start_epoch / run_epochs) + 1, math.ceil(n_epochs / run_epochs),
        start_epoch + 1, n_epochs, start_epoch + run_epochs, n_epochs))

    # download file from s3
    storage = S3Storage()
    local_dir = "/tmp"
    read_start = time.time()
    storage.download(data_bucket, train_file,
                     os.path.join(local_dir, train_file))
    storage.download(data_bucket, test_file,
                     os.path.join(local_dir, test_file))
    print("download file from s3 cost {} s".format(time.time() - read_start))

    train_set = torch.load(os.path.join(local_dir, train_file))
    test_set = torch.load(os.path.join(local_dir, test_file))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True)
    n_train_batch = len(train_loader)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=100,
                                              shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    print("read data cost {} s".format(time.time() - read_start))

    random_seed = 100
    torch.manual_seed(random_seed)

    device = 'cpu'
    model = deep_models.get_models(model_name).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # load checkpoint model if it is not the first round
    if start_epoch != 0:
        checked_file = 'checkpoint_{}.pt'.format(start_epoch - 1)
        storage.download(cp_bucket, checked_file,
                         os.path.join(local_dir, checked_file))
        checkpoint_model = torch.load(os.path.join(local_dir, checked_file))

        model.load_state_dict(checkpoint_model['model_state_dict'])
        optimizer.load_state_dict(checkpoint_model['optimizer_state_dict'])
        print("load checkpoint model at epoch {}".format(start_epoch - 1))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(host, port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()
    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
        host, port))

    # register model
    parameter_shape = []
    parameter_length = []
    model_length = 0
    for param in model.parameters():
        tmp_shape = 1
        parameter_shape.append(param.data.numpy().shape)
        for w in param.data.numpy().shape:
            tmp_shape *= w
        parameter_length.append(tmp_shape)
        model_length += tmp_shape

    ps_client.register_model(t_client, worker_index, model_name, model_length,
                             n_workers)
    ps_client.exist_model(t_client, model_name)
    print("register and check model >>> name = {}, length = {}".format(
        model_name, model_length))

    # Training the Model
    train_start = time.time()
    iter_counter = 0
    for epoch in range(start_epoch, min(start_epoch + run_epochs, n_epochs)):

        model.train()
        epoch_start = time.time()

        train_acc = Accuracy()
        train_loss = Average()

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            batch_start = time.time()
            batch_cal_time = 0
            batch_comm_time = 0

            # pull latest model
            ps_client.can_pull(t_client, model_name, iter_counter,
                               worker_index)
            latest_model = ps_client.pull_model(t_client, model_name,
                                                iter_counter, worker_index)
            pos = 0
            for layer_index, param in enumerate(model.parameters()):
                param.data = Variable(
                    torch.from_numpy(
                        np.asarray(latest_model[pos:pos +
                                                parameter_length[layer_index]],
                                   dtype=np.float32).reshape(
                                       parameter_shape[layer_index])))
                pos += parameter_length[layer_index]
            batch_comm_time += time.time() - batch_start

            batch_cal_start = time.time()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
            optimizer.zero_grad()
            loss.backward()

            # flatten and concat gradients of weight and bias
            param_grad = np.zeros((1))
            for param in model.parameters():
                param_grad = np.concatenate(
                    (param_grad, param.grad.data.numpy().flatten()))
            param_grad = np.delete(param_grad, 0)
            batch_cal_time += time.time() - batch_cal_start

            # push gradient to PS
            batch_push_start = time.time()
            ps_client.can_push(t_client, model_name, iter_counter,
                               worker_index)
            ps_client.push_grad(t_client, model_name, param_grad,
                                -1. * learning_rate / n_workers, iter_counter,
                                worker_index)
            ps_client.can_pull(t_client, model_name, iter_counter + 1,
                               worker_index)  # sync all workers
            batch_comm_time += time.time() - batch_push_start

            train_acc.update(outputs, targets)
            train_loss.update(loss.item(), inputs.size(0))

            optimizer.step()
            iter_counter += 1

            if batch_idx % 10 == 0:
                print(
                    'Epoch: [%d/%d], Batch: [%d/%d], Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                    'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                    % (epoch + 1, n_epochs, batch_idx + 1, n_train_batch,
                       time.time() - train_start, loss.item(),
                       time.time() - epoch_start, time.time() - batch_start,
                       batch_cal_time, batch_comm_time))

        test_loss, test_acc = test(epoch, model, test_loader)

        print(
            'Epoch: {}/{},'.format(epoch + 1, n_epochs),
            'train loss: {},'.format(train_loss),
            'train acc: {},'.format(train_acc),
            'test loss: {},'.format(test_loss),
            'test acc: {}.'.format(test_acc),
        )

    # training is not finished yet, invoke next round
    if epoch < n_epochs - 1:
        checkpoint_model = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss.average
        }

        checked_file = 'checkpoint_{}.pt'.format(epoch)

        if worker_index == 0:
            torch.save(checkpoint_model, os.path.join(local_dir, checked_file))
            storage.upload(cp_bucket, checked_file,
                           os.path.join(local_dir, checked_file))
            print("checkpoint model at epoch {} saved!".format(epoch))

        print(
            "Invoking the next round of functions. round: {}/{}, start epoch: {}, run epoch: {}"
            .format(
                int((epoch + 1) / run_epochs) + 1,
                math.ceil(n_epochs / run_epochs), epoch + 1, run_epochs))
        lambda_client = boto3.client('lambda')
        payload = {
            'train_file': event['train_file'],
            'test_file': event['test_file'],
            'data_bucket': event['data_bucket'],
            'n_features': event['n_features'],
            'n_classes': event['n_classes'],
            'n_workers': event['n_workers'],
            'worker_index': event['worker_index'],
            'cp_bucket': event['cp_bucket'],
            'host': event['host'],
            'port': event['port'],
            'model': event['model'],
            'optim': event['optim'],
            'sync_mode': event['sync_mode'],
            'lr': event['lr'],
            'batch_size': event['batch_size'],
            'n_epochs': event['n_epochs'],
            'start_epoch': epoch + 1,
            'run_epochs': event['run_epochs'],
            'function_name': event['function_name']
        }
        lambda_client.invoke(FunctionName=function_name,
                             InvocationType='Event',
                             Payload=json.dumps(payload))

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
Example #5
def handler(event, context):
    start_time = time.time()
    worker_index = event['rank']
    num_workers = event['num_workers']
    host = event['host']
    port = event['port']
    size = event['size']

    print('number of workers = {}'.format(num_workers))
    print('worker index = {}'.format(worker_index))
    print("host = {}".format(host))
    print("port = {}".format(port))
    print("size = {}".format(size))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(host, port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
        host, port))

    # register model
    ps_client.register_model(t_client, worker_index, MODEL_NAME, size,
                             num_workers)
    ps_client.exist_model(t_client, MODEL_NAME)
    print("register and check model >>> name = {}, length = {}".format(
        MODEL_NAME, size))

    # Training the Model
    train_start = time.time()
    iter_counter = 0
    for epoch in range(NUM_EPOCHS):
        epoch_start = time.time()
        for batch_index in range(NUM_BATCHES):
            print("------worker {} epoch {} batch {}------".format(
                worker_index, epoch, batch_index))
            batch_start = time.time()

            loss = 0.0

            # pull latest model
            ps_client.can_pull(t_client, MODEL_NAME, iter_counter,
                               worker_index)
            pull_start = time.time()
            latest_model = ps_client.pull_model(t_client, MODEL_NAME,
                                                iter_counter, worker_index)
            pull_time = time.time() - pull_start

            w_b_grad = np.random.rand(1, size).astype(np.double).flatten()

            # push gradient to PS
            ps_client.can_push(t_client, MODEL_NAME, iter_counter,
                               worker_index)
            push_start = time.time()
            ps_client.push_grad(t_client, MODEL_NAME, w_b_grad, LEARNING_RATE,
                                iter_counter, worker_index)
            push_time = time.time() - push_start
            ps_client.can_pull(t_client, MODEL_NAME, iter_counter + 1,
                               worker_index)  # sync all workers

            print(
                'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                'batch cost %.4f s: pull model cost %.4f s, push update cost %.4f s'
                % (epoch + 1, NUM_EPOCHS, batch_index, NUM_BATCHES,
                   time.time() - train_start, loss, time.time() - epoch_start,
                   time.time() - batch_start, pull_time, push_time))
            iter_counter += 1

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
Example #6
def handler(event, context):
    startTs = time.time()
    num_features = event['num_features']
    learning_rate = event["learning_rate"]
    batch_size = event["batch_size"]
    num_epochs = event["num_epochs"]
    validation_ratio = event["validation_ratio"]

    # Reading data from S3
    bucket_name = event['bucket_name']
    key = urllib.parse.unquote_plus(event['key'], encoding='utf-8')
    print(f"Reading training data from bucket = {bucket_name}, key = {key}")
    key_splits = key.split("_")
    worker_index = int(key_splits[0])
    num_worker = int(key_splits[1])

    # read file from s3
    file = get_object(bucket_name, key).read().decode('utf-8').split("\n")
    print("read data cost {} s".format(time.time() - startTs))

    parse_start = time.time()
    dataset = SparseDatasetWithLines(file, num_features)
    print("parse data cost {} s".format(time.time() - parse_start))

    preprocess_start = time.time()
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_ratio * dataset_size))
    np.random.seed(42)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    train_set = [dataset[i] for i in train_indices]
    val_set = [dataset[i] for i in val_indices]

    print("preprocess data cost {} s".format(time.time() - preprocess_start))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(constants.HOST, constants.PORT)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
        constants.HOST, constants.PORT))

    svm = SparseSVM(train_set, val_set, num_features, num_epochs,
                    learning_rate, batch_size)

    # register model
    model_name = "w.b"
    model_length = num_features
    ps_client.register_model(t_client, worker_index, model_name, model_length,
                             num_worker)
    ps_client.exist_model(t_client, model_name)
    print("register and check model >>> name = {}, length = {}".format(
        model_name, model_length))

    # Training the Model
    train_start = time.time()
    iter_counter = 0

    for epoch in range(num_epochs):
        epoch_start = time.time()
        num_batches = math.floor(len(train_set) / batch_size)
        print(f"worker {worker_index} epoch {epoch}")

        for batch_idx in range(num_batches):
            batch_start = time.time()
            # pull latest model
            ps_client.can_pull(t_client, model_name, iter_counter,
                               worker_index)
            latest_model = np.asarray(
                ps_client.pull_model(t_client, model_name, iter_counter,
                                     worker_index))
            svm.weights = torch.from_numpy(latest_model).reshape(
                num_features, 1)

            batch_ins, batch_label = svm.next_batch(batch_idx)
            acc = svm.one_epoch(batch_idx, epoch)
            compute_end = time.time()

            sync_start = time.time()
            w_update = svm.weights.numpy().flatten() - latest_model
            ps_client.can_push(t_client, model_name, iter_counter,
                               worker_index)
            ps_client.push_update(t_client, model_name, w_update,
                                  learning_rate, iter_counter, worker_index)
            ps_client.can_pull(t_client, model_name, iter_counter + 1,
                               worker_index)  # sync all workers
            sync_time = time.time() - sync_start

            print(
                'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, train acc: %.4f, epoch cost %.4f, '
                'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                % (epoch + 1, num_epochs, batch_idx + 1, num_batches,
                   time.time() - train_start, acc, time.time() - epoch_start,
                   time.time() - batch_start, compute_end - batch_start,
                   sync_time))
            iter_counter += 1

        val_acc = svm.evaluate()
        print("Epoch takes {}s, validation accuracy: {}".format(
            time.time() - epoch_start, val_acc))
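
This handler recovers its rank and the worker count from the S3 key itself ("<rank>_<num_workers>"). A hypothetical triggering event matching the fields it reads (all values are placeholders):

sample_event = {
    'bucket_name': 'svm-training-shards',
    'key': '0_4',              # worker 0 of 4, parsed from the key
    'num_features': 30,
    'learning_rate': 0.01,
    'batch_size': 100,
    'num_epochs': 10,
    'validation_ratio': 0.2
}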
Example #7
def handler(event, context):
    start_time = time.time()
    bucket = event['data_bucket']
    worker_index = event['rank']
    num_worker = event['num_workers']
    key = event['key']

    print('bucket = {}'.format(bucket))
    print('number of workers = {}'.format(num_worker))
    print('worker index = {}'.format(worker_index))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(constants.HOST, constants.PORT)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
        constants.HOST, constants.PORT))

    print('key = {}'.format(key))

    # read file from s3
    readS3_start = time.time()
    train_path = download_file(bucket, key)
    trainset = torch.load(train_path)
    test_path = download_file(bucket, test_file)
    testset = torch.load(test_path)

    print("read data cost {} s".format(time.time() - readS3_start))
    preprocess_start = time.time()
    batch_size = 200
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=batch_size,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=batch_size,
                                              shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')
    device = 'cpu'
    print("preprocess data cost {} s".format(time.time() - preprocess_start))

    model = MobileNet()
    model = model.to(device)

    # Loss and Optimizer
    # Softmax is internally computed.
    # Set parameters to be updated.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=5e-4)

    # register model
    model_name = "mobilenet"
    parameter_shape = []
    parameter_length = []
    model_length = 0
    for param in model.parameters():
        tmp_shape = 1
        parameter_shape.append(param.data.numpy().shape)
        for w in param.data.numpy().shape:
            tmp_shape *= w
        parameter_length.append(tmp_shape)
        model_length += tmp_shape

    ps_client.register_model(t_client, worker_index, model_name, model_length,
                             num_worker)
    ps_client.exist_model(t_client, model_name)
    print("register and check model >>> name = {}, length = {}".format(
        model_name, model_length))

    # Training the Model
    train_start = time.time()
    iter_counter = 0
    for epoch in range(num_epochs):
        epoch_start = time.time()
        model.train()
        for batch_index, (inputs, targets) in enumerate(train_loader):
            print("------worker {} epoch {} batch {}------".format(
                worker_index, epoch, batch_index))
            batch_start = time.time()

            # pull latest model
            ps_client.can_pull(t_client, model_name, iter_counter,
                               worker_index)
            latest_model = ps_client.pull_model(t_client, model_name,
                                                iter_counter, worker_index)
            pos = 0
            for layer_index, param in enumerate(model.parameters()):
                param.data = Variable(
                    torch.from_numpy(
                        np.asarray(latest_model[pos:pos +
                                                parameter_length[layer_index]],
                                   dtype=np.float32).reshape(
                                       parameter_shape[layer_index])))
                pos += parameter_length[layer_index]

            # Forward + Backward + Optimize
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()

            # flatten and concat gradients of weight and bias
            param_grad = np.zeros((1))
            for param in model.parameters():
                param_grad = np.concatenate(
                    (param_grad, param.grad.data.numpy().flatten()))
            param_grad = np.delete(param_grad, 0)

            # push gradient to PS
            sync_start = time.time()
            cal_time = sync_start - batch_start
            ps_client.can_push(t_client, model_name, iter_counter,
                               worker_index)
            ps_client.push_grad(t_client, model_name, param_grad,
                                learning_rate, iter_counter, worker_index)
            ps_client.can_pull(t_client, model_name, iter_counter + 1,
                               worker_index)  # sync all workers
            sync_time = time.time() - sync_start

            print(
                'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                % (epoch + 1, num_epochs, batch_index + 1, len(train_loader),
                   time.time() - train_start, loss.item(),
                   time.time() - epoch_start, time.time() - batch_start,
                   cal_time, sync_time))
            optimizer.step()
            iter_counter += 1

        # test the model after each epoch
        test(epoch, model, test_loader, criterion, device)
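
download_file() is not shown in this snippet. A minimal sketch of the assumed behavior, fetching an S3 object into Lambda's writable /tmp and returning the local path (the helper name comes from the call sites above; the body is an assumption):

import os

import boto3


def download_file(bucket, key, local_dir='/tmp'):
    local_path = os.path.join(local_dir, os.path.basename(key))
    boto3.client('s3').download_file(bucket, key, local_path)
    return local_path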
Example #8
def main():
    start_time = time.time()

    parser = argparse.ArgumentParser()
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--host', type=str, default=constants.HOST)
    parser.add_argument('--port', type=int, default=constants.PORT)
    parser.add_argument('--size', type=int, default=100)
    args = parser.parse_args()
    print(args)

    print("host = {}".format(args.host))
    print("port = {}".format(args.port))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(args.host, args.port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(args.host, args.port))

    # register model
    ps_client.register_model(t_client, args.rank, MODEL_NAME, args.size, args.num_workers)
    ps_client.exist_model(t_client, MODEL_NAME)
    print("register and check model >>> name = {}, length = {}".format(MODEL_NAME, args.size))

    # Training the Model
    train_start = time.time()
    iter_counter = 0
    for epoch in range(NUM_EPOCHS):
        epoch_start = time.time()
        for batch_index in range(NUM_BATCHES):
            print("------worker {} epoch {} batch {}------"
                  .format(args.rank, epoch, batch_index))
            batch_start = time.time()

            loss = 0.0

            # pull latest model
            ps_client.can_pull(t_client, MODEL_NAME, iter_counter, args.rank)
            pull_start = time.time()
            latest_model = ps_client.pull_model(t_client, MODEL_NAME, iter_counter, args.rank)
            pull_time = time.time() - pull_start

            cal_start = time.time()
            w_b_grad = np.random.rand(1, args.size).astype(np.double).flatten()
            cal_time = time.time() - cal_start

            # push gradient to PS
            ps_client.can_push(t_client, MODEL_NAME, iter_counter, args.rank)
            push_start = time.time()
            ps_client.push_grad(t_client, MODEL_NAME, w_b_grad, LEARNING_RATE, iter_counter, args.rank)
            push_time = time.time() - push_start
            ps_client.can_pull(t_client, MODEL_NAME, iter_counter + 1, args.rank)  # sync all workers

            print('Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                  'batch cost %.4f s: cal cost %.4f s, pull model cost %.4f s, push update cost %.4f s'
                  % (epoch + 1, NUM_EPOCHS, batch_index, NUM_BATCHES,
                     time.time() - train_start, loss, time.time() - epoch_start,
                     time.time() - batch_start, cal_time, pull_time, push_time))
            iter_counter += 1

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
Example #9
def handler(event, context):
    start_time = time.time()

    # dataset setting
    file = event['file']
    data_bucket = event['data_bucket']
    dataset_type = event['dataset_type']
    assert dataset_type == "sparse_libsvm"
    n_features = event['n_features']
    n_classes = event['n_classes']
    n_workers = event['n_workers']
    worker_index = event['worker_index']

    # ps setting
    host = event['host']
    port = event['port']

    # training setting
    model_name = event['model']
    optim = event['optim']
    sync_mode = event['sync_mode']
    assert model_name.lower() in MLModel.Sparse_Linear_Models
    assert optim.lower() == Optimization.Grad_Avg
    assert sync_mode.lower() == Synchronization.Reduce

    # hyper-parameter
    learning_rate = event['lr']
    batch_size = event['batch_size']
    n_epochs = event['n_epochs']
    valid_ratio = event['valid_ratio']

    print('bucket = {}'.format(data_bucket))
    print("file = {}".format(file))
    print('number of workers = {}'.format(n_workers))
    print('worker index = {}'.format(worker_index))
    print('model = {}'.format(model_name))
    print('host = {}'.format(host))
    print('port = {}'.format(port))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(host, port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()
    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
        host, port))

    # Read file from s3
    read_start = time.time()
    storage = S3Storage()
    lines = storage.load(file, data_bucket).read().decode('utf-8').split("\n")
    print("read data cost {} s".format(time.time() - read_start))

    parse_start = time.time()
    dataset = libsvm_dataset.from_lines(lines, n_features, dataset_type)
    print("parse data cost {} s".format(time.time() - parse_start))

    preprocess_start = time.time()
    # Creating data indices for training and validation splits:
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(valid_ratio * dataset_size))

    shuffle_dataset = True
    random_seed = 100
    if shuffle_dataset:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    # split train set and test set
    train_set = [dataset[i] for i in train_indices]
    n_train_batch = math.floor(len(train_set) / batch_size)
    val_set = [dataset[i] for i in val_indices]
    print("preprocess data cost {} s, dataset size = {}".format(
        time.time() - preprocess_start, dataset_size))

    model = linear_models.get_sparse_model(model_name, train_set, val_set,
                                           n_features, n_epochs, learning_rate,
                                           batch_size)

    # register model
    model_name = "w.b"
    weight_length = n_features
    bias_length = 1
    model_length = weight_length + bias_length
    ps_client.register_model(t_client, worker_index, model_name, model_length,
                             n_workers)
    ps_client.exist_model(t_client, model_name)
    print("register and check model >>> name = {}, length = {}".format(
        model_name, model_length))

    # Training the Model
    train_start = time.time()
    iter_counter = 0
    for epoch in range(n_epochs):
        epoch_start = time.time()
        epoch_cal_time = 0
        epoch_comm_time = 0
        epoch_loss = 0.

        for batch_idx in range(n_train_batch):
            batch_start = time.time()
            batch_comm_time = 0

            # pull latest model
            ps_client.can_pull(t_client, model_name, iter_counter,
                               worker_index)
            latest_model = ps_client.pull_model(t_client, model_name,
                                                iter_counter, worker_index)
            model.weight = torch.from_numpy(
                np.asarray(latest_model[:weight_length]).astype(
                    np.float32).reshape(n_features, 1))
            model.bias = float(latest_model[-1])
            batch_comm_time += time.time() - batch_start

            batch_loss, batch_acc = model.one_batch()
            epoch_loss += batch_loss.average

            w_b = np.concatenate((model.weight.double().numpy().flatten(),
                                  np.array([model.bias]).astype(np.double)))
            w_b_update = np.subtract(w_b, latest_model)
            batch_cal_time = time.time() - batch_start

            # push gradient to PS
            batch_comm_start = time.time()
            ps_client.can_push(t_client, model_name, iter_counter,
                               worker_index)
            ps_client.push_grad(t_client, model_name, w_b_update,
                                1.0 / n_workers, iter_counter, worker_index)
            ps_client.can_pull(t_client, model_name, iter_counter + 1,
                               worker_index)  # sync all workers
            batch_comm_time += time.time() - batch_comm_start

            epoch_cal_time += batch_cal_time
            epoch_comm_time += batch_comm_time

            if batch_idx % 10 == 0:
                print(
                    'Epoch: [%d/%d], Batch: [%d/%d], Time: %.4f, Loss: %.4f, Accuracy: %.4f, '
                    'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                    % (epoch + 1, n_epochs, batch_idx + 1, n_train_batch,
                       time.time() - train_start, batch_loss.average,
                       batch_acc.accuracy, time.time() - batch_start,
                       batch_cal_time, batch_comm_time))

            iter_counter += 1

        # Test the Model
        test_start = time.time()
        test_loss, test_acc = model.evaluate()
        test_time = time.time() - test_start

        print(
            "Epoch: [{}/{}] finishes, Batch: [{}/{}], Time: {:.4f}, Loss: {:.4f}, epoch cost {:.4f} s, "
            "calculation cost = {:.4f} s, synchronization cost {:.4f} s, test cost {:.4f} s, "
            "accuracy of the model on the {} test samples: {}, loss = {}".
            format(epoch + 1, n_epochs, batch_idx + 1, n_train_batch,
                   time.time() - train_start, epoch_loss,
                   time.time() - epoch_start, epoch_cal_time, epoch_comm_time,
                   test_time, len(val_set), test_acc.accuracy,
                   test_loss.average))

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
Example #10
def handler(event, context):
    startTs = time.time()
    num_features = event['num_features']
    learning_rate = event["learning_rate"]
    batch_size = event["batch_size"]
    num_epochs = event["num_epochs"]
    validation_ratio = event["validation_ratio"]

    # Reading data from S3
    bucket_name = event['bucket_name']
    key = urllib.parse.unquote_plus(event['key'], encoding='utf-8')
    print(f"Reading training data from bucket = {bucket_name}, key = {key}")
    key_splits = key.split("_")
    worker_index = int(key_splits[0])
    num_worker = int(key_splits[1])

    # read file from s3
    file = get_object(bucket_name, key).read().decode('utf-8').split("\n")
    print("read data cost {} s".format(time.time() - startTs))

    parse_start = time.time()
    dataset = SparseDatasetWithLines(file, num_features)
    print("parse data cost {} s".format(time.time() - parse_start))

    preprocess_start = time.time()
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_ratio * dataset_size))
    np.random.seed(42)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    train_set = [dataset[i] for i in train_indices]
    val_set = [dataset[i] for i in val_indices]

    print("preprocess data cost {} s".format(time.time() - preprocess_start))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(constants.HOST, constants.PORT)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
        constants.HOST, constants.PORT))

    # register model
    model_name = "w.b"
    model_length = num_features + 1
    ps_client.register_model(t_client, worker_index, model_name, model_length,
                             num_worker)
    ps_client.exist_model(t_client, model_name)
    print("register and check model >>> name = {}, length = {}".format(
        model_name, model_length))

    # Training the Model
    train_start = time.time()
    iter_counter = 0
    lr = LogisticRegression(train_set, val_set, num_features, num_epochs,
                            learning_rate, batch_size)
    # Training the Model
    for epoch in range(num_epochs):
        epoch_start = time.time()
        num_batches = math.floor(len(train_set) / batch_size)
        print(f"worker {worker_index} epoch {epoch}")

        for batch_idx in range(num_batches):
            batch_start = time.time()
            # pull latest model
            ps_client.can_pull(t_client, model_name, iter_counter,
                               worker_index)
            latest_model = ps_client.pull_model(t_client, model_name,
                                                iter_counter, worker_index)
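            # note: in this snippet lr.grad stores the weight vector itself,
            # not a gradient; it is updated in place further below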
            lr.grad = torch.from_numpy(np.asarray(latest_model[:-1])).reshape(
                num_features, 1)
            lr.bias = float(latest_model[-1])

            compute_start = time.time()
            batch_ins, batch_label = lr.next_batch(batch_idx)
            batch_grad = torch.zeros(lr.n_input, 1, requires_grad=False)
            batch_bias = 0.0
            train_loss = Loss()
            train_acc = Accuracy()

            for i in range(len(batch_ins)):
                z = lr.forward(batch_ins[i])
                h = lr.sigmoid(z)
                loss = lr.loss(h, batch_label[i])
                # print("z= {}, h= {}, loss = {}".format(z, h, loss))
                train_loss.update(loss, 1)
                train_acc.update(h, batch_label[i])
                g = lr.backward(batch_ins[i], h.item(), batch_label[i])
                batch_grad.add_(g)
                batch_bias += np.sum(h.item() - batch_label[i])
            batch_grad = batch_grad.div(len(batch_ins))
            batch_bias = batch_bias / len(batch_ins)
            batch_grad.mul_(-1.0 * learning_rate)
            lr.grad.add_(batch_grad)
            lr.bias = lr.bias - batch_bias * learning_rate

            np_grad = lr.grad.numpy().flatten()
            w_b_grad = np.append(np_grad, lr.bias)
            compute_end = time.time()

            sync_start = time.time()
            ps_client.can_push(t_client, model_name, iter_counter,
                               worker_index)
            ps_client.push_grad(t_client, model_name, w_b_grad, learning_rate,
                                iter_counter, worker_index)
            ps_client.can_pull(t_client, model_name, iter_counter + 1,
                               worker_index)  # sync all workers
            sync_time = time.time() - sync_start

            print(
                'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                % (epoch + 1, num_epochs, batch_idx + 1, num_batches,
                   time.time() - train_start, train_loss.average,
                   time.time() - epoch_start, time.time() - batch_start,
                   compute_end - batch_start, sync_time))
            iter_counter += 1

        val_loss, val_acc = lr.evaluate()
        print(f"Validation loss: {val_loss}, validation accuracy: {val_acc}")
        print(f"Epoch takes {time.time() - epoch_start}s")